diff --git a/.github/workflows/black_lint.yml b/.github/workflows/black_lint.yml new file mode 100644 index 0000000..b2cd244 --- /dev/null +++ b/.github/workflows/black_lint.yml @@ -0,0 +1,10 @@ +name: Lint + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: psf/black@stable \ No newline at end of file diff --git a/.gitignore b/.gitignore index 92ccf57..5143b08 100644 --- a/.gitignore +++ b/.gitignore @@ -63,3 +63,10 @@ cython_debug/ # Custome folder # testing test_data/ + + +.vscode/ + +.pytest_cache/ + +Binette_TestData/ \ No newline at end of file diff --git a/binette/__init__.py b/binette/__init__.py index 77139f6..976498a 100644 --- a/binette/__init__.py +++ b/binette/__init__.py @@ -1 +1 @@ -__version__ = '1.0.3' \ No newline at end of file +__version__ = "1.0.3" diff --git a/binette/bin_manager.py b/binette/bin_manager.py index 6891ac5..6cc30c7 100644 --- a/binette/bin_manager.py +++ b/binette/bin_manager.py @@ -8,10 +8,13 @@ import networkx as nx from typing import List, Dict, Iterable, Tuple, Set, Mapping + class Bin: counter = 0 - def __init__(self, contigs: Iterable[str], origin: str, name: str, is_original:bool=False) -> None: + def __init__( + self, contigs: Iterable[str], origin: str, name: str, is_original: bool = False + ) -> None: """ Initialize a Bin object. @@ -59,9 +62,11 @@ def __str__(self) -> str: :return: The string representation of the Bin object. """ - return f"Bin {self.id} from {';'.join(self.origin)} ({len(self.contigs)} contigs)" + return ( + f"Bin {self.id} from {';'.join(self.origin)} ({len(self.contigs)} contigs)" + ) - def overlaps_with(self, other: 'Bin') -> Set[str]: + def overlaps_with(self, other: "Bin") -> Set[str]: """ Find the contigs that overlap between this bin and another bin. @@ -83,7 +88,6 @@ def overlaps_with(self, other: 'Bin') -> Set[str]: # return Bin(contigs, origin, name) - def add_length(self, length: int) -> None: """ Add the length attribute to the Bin object if the provided length is a positive integer. @@ -107,9 +111,10 @@ def add_N50(self, n50: int) -> None: self.N50 = n50 else: raise ValueError("N50 should be a positive integer.") - - def add_quality(self, completeness: float, contamination: float, contamination_weight: float) -> None: + def add_quality( + self, completeness: float, contamination: float, contamination_weight: float + ) -> None: """ Set the quality attributes of the bin. @@ -121,7 +126,7 @@ def add_quality(self, completeness: float, contamination: float, contamination_w self.contamination = contamination self.score = completeness - contamination_weight * contamination - def intersection(self, *others: 'Bin') -> 'Bin': + def intersection(self, *others: "Bin") -> "Bin": """ Compute the intersection of the bin with other bins. @@ -135,7 +140,7 @@ def intersection(self, *others: 'Bin') -> 'Bin': return Bin(contigs, origin, name) - def difference(self, *others: 'Bin') -> 'Bin': + def difference(self, *others: "Bin") -> "Bin": """ Compute the difference between the bin and other bins. @@ -149,7 +154,7 @@ def difference(self, *others: 'Bin') -> 'Bin': return Bin(contigs, origin, name) - def union(self, *others: 'Bin') -> 'Bin': + def union(self, *others: "Bin") -> "Bin": """ Compute the union of the bin with other bins. @@ -162,14 +167,13 @@ def union(self, *others: 'Bin') -> 'Bin': origin = "union" return Bin(contigs, origin, name) - def is_complete_enough(self, min_completeness: float) -> bool: """ Determine if a bin is complete enough based on completeness threshold. :param min_completeness: The minimum completeness required for a bin. - + :raises ValueError: If completeness has not been set (is None). :return: True if the bin meets the min_completeness threshold; False otherwise. @@ -182,15 +186,16 @@ def is_complete_enough(self, min_completeness: float) -> bool: ) return self.completeness >= min_completeness - - def is_high_quality(self, min_completeness: float, max_contamination: float) -> bool: + def is_high_quality( + self, min_completeness: float, max_contamination: float + ) -> bool: """ Determine if a bin is considered high quality based on completeness and contamination thresholds. :param min_completeness: The minimum completeness required for a bin to be considered high quality. :param max_contamination: The maximum allowed contamination for a bin to be considered high quality. - + :raises ValueError: If either completeness or contamination has not been set (is None). :return: True if the bin meets the high quality criteria; False otherwise. @@ -201,11 +206,15 @@ def is_high_quality(self, min_completeness: float, max_contamination: float) -> "and therefore cannot be assessed for high quality." ) - return self.completeness >= min_completeness and self.contamination <= max_contamination - + return ( + self.completeness >= min_completeness + and self.contamination <= max_contamination + ) -def get_bins_from_directory(bin_dir: Path, set_name: str, fasta_extensions: Set[str]) -> List[Bin]: +def get_bins_from_directory( + bin_dir: Path, set_name: str, fasta_extensions: Set[str] +) -> List[Bin]: """ Retrieves a list of Bin objects from a directory containing bin FASTA files. @@ -216,14 +225,22 @@ def get_bins_from_directory(bin_dir: Path, set_name: str, fasta_extensions: Set[ :return: A list of Bin objects created from the bin FASTA files. """ bins = [] - fasta_extensions |= {f".{ext}" for ext in fasta_extensions if not ext.startswith(".")} # adding a dot in case given extension are lacking one - bin_fasta_files = (fasta_file for fasta_file in bin_dir.glob("*") if set(fasta_file.suffixes) & fasta_extensions) + fasta_extensions |= { + f".{ext}" for ext in fasta_extensions if not ext.startswith(".") + } # adding a dot in case given extension are lacking one + bin_fasta_files = ( + fasta_file + for fasta_file in bin_dir.glob("*") + if set(fasta_file.suffixes) & fasta_extensions + ) for bin_fasta_path in bin_fasta_files: bin_name = bin_fasta_path.name - contigs = {name for name, _ in pyfastx.Fasta(str(bin_fasta_path), build_index=False)} + contigs = { + name for name, _ in pyfastx.Fasta(str(bin_fasta_path), build_index=False) + } bin_obj = Bin(contigs, set_name, bin_name) @@ -232,8 +249,9 @@ def get_bins_from_directory(bin_dir: Path, set_name: str, fasta_extensions: Set[ return bins - -def parse_bin_directories(bin_name_to_bin_dir: Dict[str, Path], fasta_extensions:Set[str]) -> Dict[str, Set[Bin]]: +def parse_bin_directories( + bin_name_to_bin_dir: Dict[str, Path], fasta_extensions: Set[str] +) -> Dict[str, Set[Bin]]: """ Parses multiple bin directories and returns a dictionary mapping bin names to a list of Bin objects. @@ -247,32 +265,34 @@ def parse_bin_directories(bin_name_to_bin_dir: Dict[str, Path], fasta_extensions for name, bin_dir in bin_name_to_bin_dir.items(): bins = get_bins_from_directory(bin_dir, name, fasta_extensions) set_of_bins = set(bins) - + # Calculate the number of duplicates num_duplicates = len(bins) - len(set_of_bins) - + if num_duplicates > 0: logging.warning( f'{num_duplicates} bins with identical contig compositions detected in bin set "{name}". ' - 'These bins were merged to ensure uniqueness.' + "These bins were merged to ensure uniqueness." ) # Store the unique set of bins bin_set_name_to_bins[name] = set_of_bins - return bin_set_name_to_bins -def parse_contig2bin_tables(bin_name_to_bin_tables: Dict[str, Path]) -> Dict[str, Set['Bin']]: + +def parse_contig2bin_tables( + bin_name_to_bin_tables: Dict[str, Path] +) -> Dict[str, Set["Bin"]]: """ Parses multiple contig-to-bin tables and returns a dictionary mapping bin names to a set of unique Bin objects. Logs a warning if duplicate bins are detected within a bin set. - :param bin_name_to_bin_tables: A dictionary where keys are bin set names and values are file paths or identifiers + :param bin_name_to_bin_tables: A dictionary where keys are bin set names and values are file paths or identifiers for contig-to-bin tables. Each table is parsed to extract Bin objects. - :return: A dictionary where keys are bin set names and values are sets of Bin objects. Duplicates are removed based + :return: A dictionary where keys are bin set names and values are sets of Bin objects. Duplicates are removed based on contig composition. """ bin_set_name_to_bins = {} @@ -280,19 +300,19 @@ def parse_contig2bin_tables(bin_name_to_bin_tables: Dict[str, Path]) -> Dict[str for name, contig2bin_table in bin_name_to_bin_tables.items(): bins = get_bins_from_contig2bin_table(contig2bin_table, name) set_of_bins = set(bins) - + # Calculate the number of duplicates num_duplicates = len(bins) - len(set_of_bins) - + if num_duplicates > 0: logging.warning( f'{num_duplicates*2} bins with identical contig compositions detected in bin set "{name}". ' - 'These bins were merged to ensure uniqueness.' + "These bins were merged to ensure uniqueness." ) # Store the unique set of bins bin_set_name_to_bins[name] = set_of_bins - + return bin_set_name_to_bins @@ -322,7 +342,9 @@ def get_bins_from_contig2bin_table(contig2bin_table: Path, set_name: str) -> Lis return bins -def from_bin_sets_to_bin_graph(bin_name_to_bin_set: Mapping[str, Iterable[Bin]]) -> nx.Graph: +def from_bin_sets_to_bin_graph( + bin_name_to_bin_set: Mapping[str, Iterable[Bin]] +) -> nx.Graph: """ Creates a bin graph from a dictionary of bin sets. @@ -343,7 +365,6 @@ def from_bin_sets_to_bin_graph(bin_name_to_bin_set: Mapping[str, Iterable[Bin]]) return G - def get_all_possible_combinations(clique: List) -> Iterable[Tuple]: """ Generates all possible combinations of elements from a given clique. @@ -352,7 +373,9 @@ def get_all_possible_combinations(clique: List) -> Iterable[Tuple]: :return: An iterable of tuples representing all possible combinations of elements from the clique. """ - return (c for r in range(2, len(clique) + 1) for c in itertools.combinations(clique, r)) + return ( + c for r in range(2, len(clique) + 1) for c in itertools.combinations(clique, r) + ) def get_intersection_bins(G: nx.Graph) -> Set[Bin]: @@ -369,8 +392,12 @@ def get_intersection_bins(G: nx.Graph) -> Set[Bin]: bins_combinations = get_all_possible_combinations(clique) for bins in bins_combinations: if max((b.completeness for b in bins)) < 20: - logging.debug("completeness is not good enough to create a new bin on intersection") - logging.debug(f"{[(str(b), b.completeness, b.contamination) for b in bins]}") + logging.debug( + "completeness is not good enough to create a new bin on intersection" + ) + logging.debug( + f"{[(str(b), b.completeness, b.contamination) for b in bins]}" + ) continue intersec_bin = bins[0].intersection(*bins[1:]) @@ -390,7 +417,7 @@ def get_difference_bins(G: nx.Graph) -> Set[Bin]: :return: A set of Bin objects representing the difference bins. """ difference_bins = set() - + for clique in nx.clique.find_cliques(G): bins_combinations = get_all_possible_combinations(clique) @@ -400,8 +427,12 @@ def get_difference_bins(G: nx.Graph) -> Set[Bin]: bin_diff = bin_a.difference(*(b for b in bins if b != bin_a)) if bin_a.completeness < 20: - logging.debug(f"completeness of {bin_a} is not good enough to do difference... ") - logging.debug(f"{[(str(b), b.completeness, b.contamination) for b in bins]}") + logging.debug( + f"completeness of {bin_a} is not good enough to do difference... " + ) + logging.debug( + f"{[(str(b), b.completeness, b.contamination) for b in bins]}" + ) continue if bin_diff.contigs: @@ -424,8 +455,12 @@ def get_union_bins(G: nx.Graph, max_conta: int = 50) -> Set[Bin]: bins_combinations = get_all_possible_combinations(clique) for bins in bins_combinations: if max((b.contamination for b in bins)) > max_conta: - logging.debug("Some bin are too contaminated to make a useful union bin") - logging.debug(f"{[(str(b), b.completeness, b.contamination) for b in bins]}") + logging.debug( + "Some bin are too contaminated to make a useful union bin" + ) + logging.debug( + f"{[(str(b), b.completeness, b.contamination) for b in bins]}" + ) continue bins = set(bins) @@ -463,7 +498,8 @@ def select_best_bins(bins: Set[Bin]) -> List[Bin]: logging.info(f"Selected {len(selected_bins)} bins") return selected_bins -def group_identical_bins(bins:Iterable[Bin]) -> List[List[Bin]]: + +def group_identical_bins(bins: Iterable[Bin]) -> List[List[Bin]]: """ Group identical bins together @@ -480,7 +516,7 @@ def group_identical_bins(bins:Iterable[Bin]) -> List[List[Bin]]: return list(binhash_to_bins.values()) -def dereplicate_bin_sets(bin_sets: Iterable[Set['Bin']]) -> Set['Bin']: +def dereplicate_bin_sets(bin_sets: Iterable[Set["Bin"]]) -> Set["Bin"]: """ Consolidate bins from multiple bin sets into a single set of non-redundant bins. @@ -511,6 +547,7 @@ def dereplicate_bin_sets(bin_sets: Iterable[Set['Bin']]) -> Set['Bin']: return dereplicated_bins + def get_contigs_in_bin_sets(bin_set_name_to_bins: Dict[str, Set[Bin]]) -> Set[str]: """ Processes bin sets to check for duplicated contigs and logs detailed information about each bin set. @@ -526,13 +563,22 @@ def get_contigs_in_bin_sets(bin_set_name_to_bins: Dict[str, Set[Bin]]) -> Set[st list_contigs_in_bin_sets = get_contigs_in_bins(bins) # Count duplicates - contig_counts = {contig: list_contigs_in_bin_sets.count(contig) for contig in list_contigs_in_bin_sets} - duplicated_contigs = {contig: count for contig, count in contig_counts.items() if count > 1} + contig_counts = { + contig: list_contigs_in_bin_sets.count(contig) + for contig in list_contigs_in_bin_sets + } + duplicated_contigs = { + contig: count for contig, count in contig_counts.items() if count > 1 + } if duplicated_contigs: logging.warning( f"Bin set '{bin_set_name}' contains {len(duplicated_contigs)} duplicated contigs. " - "Details: " + ", ".join(f"{contig} (found {count} times)" for contig, count in duplicated_contigs.items()) + "Details: " + + ", ".join( + f"{contig} (found {count} times)" + for contig, count in duplicated_contigs.items() + ) ) # Unique contigs in current bin set @@ -571,7 +617,10 @@ def rename_bin_contigs(bins: Iterable[Bin], contig_to_index: dict): b.contigs = {contig_to_index[contig] for contig in b.contigs} b.hash = hash(str(sorted(b.contigs))) -def create_intermediate_bins(bin_set_name_to_bins: Mapping[str, Iterable[Bin]]) -> Set[Bin]: + +def create_intermediate_bins( + bin_set_name_to_bins: Mapping[str, Iterable[Bin]] +) -> Set[Bin]: """ Creates intermediate bins from a dictionary of bin sets. @@ -595,4 +644,3 @@ def create_intermediate_bins(bin_set_name_to_bins: Mapping[str, Iterable[Bin]]) logging.info(f"{len(union_bins)} bins created on unions.") return difference_bins | intersection_bins | union_bins - diff --git a/binette/bin_quality.py b/binette/bin_quality.py index 5a46ab2..154660d 100644 --- a/binette/bin_quality.py +++ b/binette/bin_quality.py @@ -18,7 +18,12 @@ from checkm2 import keggData, modelPostprocessing, modelProcessing # noqa: E402 -def get_bins_metadata_df(bins: Iterable[Bin], contig_to_cds_count: Dict[str, int], contig_to_aa_counter: Dict[str, Counter], contig_to_aa_length: Dict[str, int]) -> pd.DataFrame: +def get_bins_metadata_df( + bins: Iterable[Bin], + contig_to_cds_count: Dict[str, int], + contig_to_aa_counter: Dict[str, Counter], + contig_to_aa_length: Dict[str, int], +) -> pd.DataFrame: """ Generate a DataFrame containing metadata for a list of bins. @@ -36,8 +41,20 @@ def get_bins_metadata_df(bins: Iterable[Bin], contig_to_cds_count: Dict[str, int for bin_obj in bins: bin_metadata = { "Name": bin_obj.id, - "CDS": sum((contig_to_cds_count[c] for c in bin_obj.contigs if c in contig_to_cds_count)), - "AALength": sum((contig_to_aa_length[c] for c in bin_obj.contigs if c in contig_to_aa_length)), + "CDS": sum( + ( + contig_to_cds_count[c] + for c in bin_obj.contigs + if c in contig_to_cds_count + ) + ), + "AALength": sum( + ( + contig_to_aa_length[c] + for c in bin_obj.contigs + if c in contig_to_aa_length + ) + ), } bin_aa_counter = Counter() @@ -57,7 +74,10 @@ def get_bins_metadata_df(bins: Iterable[Bin], contig_to_cds_count: Dict[str, int metadata_df = metadata_df.set_index("Name", drop=False) return metadata_df -def get_diamond_feature_per_bin_df(bins: Iterable[Bin], contig_to_kegg_counter: Dict[str, Counter]) -> Tuple[pd.DataFrame, int]: + +def get_diamond_feature_per_bin_df( + bins: Iterable[Bin], contig_to_kegg_counter: Dict[str, Counter] +) -> Tuple[pd.DataFrame, int]: """ Generate a DataFrame containing Diamond feature counts per bin and completeness information for pathways, categories, and modules. @@ -83,7 +103,9 @@ def get_diamond_feature_per_bin_df(bins: Iterable[Bin], contig_to_kegg_counter: bin_to_ko_counter[bin_obj.id] = bin_ko_counter - ko_count_per_bin_df = pd.DataFrame(bin_to_ko_counter, index=defaultKOs).transpose().fillna(0) + ko_count_per_bin_df = ( + pd.DataFrame(bin_to_ko_counter, index=defaultKOs).transpose().fillna(0) + ) ko_count_per_bin_df = ko_count_per_bin_df.astype(int) ko_count_per_bin_df["Name"] = ko_count_per_bin_df.index @@ -92,12 +114,16 @@ def get_diamond_feature_per_bin_df(bins: Iterable[Bin], contig_to_kegg_counter: KO_pathways = KeggCalc.calculate_KO_group("KO_Pathways", ko_count_per_bin_df.copy()) logging.debug("Calculating category completeness information") - KO_categories = KeggCalc.calculate_KO_group("KO_Categories", ko_count_per_bin_df.copy()) + KO_categories = KeggCalc.calculate_KO_group( + "KO_Categories", ko_count_per_bin_df.copy() + ) logging.debug("Calculating module completeness information") KO_modules = KeggCalc.calculate_module_completeness(ko_count_per_bin_df.copy()) - diamond_complete_results = pd.concat([ko_count_per_bin_df, KO_pathways, KO_modules, KO_categories], axis=1) + diamond_complete_results = pd.concat( + [ko_count_per_bin_df, KO_pathways, KO_modules, KO_categories], axis=1 + ) return diamond_complete_results, len(defaultKOs) @@ -121,7 +147,8 @@ def compute_N50(list_of_lengths) -> int: cum_length += length return length -def add_bin_size_and_N50(bins: Iterable[Bin], contig_to_size: Dict[str,int]): + +def add_bin_size_and_N50(bins: Iterable[Bin], contig_to_size: Dict[str, int]): """ Add bin size and N50 to a list of bin objects. @@ -136,7 +163,9 @@ def add_bin_size_and_N50(bins: Iterable[Bin], contig_to_size: Dict[str,int]): bin_obj.add_N50(n50) -def add_bin_metrics(bins: Set[Bin], contig_info: Dict, contamination_weight: float, threads: int = 1): +def add_bin_metrics( + bins: Set[Bin], contig_info: Dict, contamination_weight: float, threads: int = 1 +): """ Add metrics to a Set of bins. @@ -184,15 +213,17 @@ def chunks(iterable: Iterable, size: int) -> Iterator[Tuple]: return iter(lambda: tuple(islice(it, size)), ()) -def assess_bins_quality_by_chunk(bins: Iterable[Bin], +def assess_bins_quality_by_chunk( + bins: Iterable[Bin], contig_to_kegg_counter: Dict, contig_to_cds_count: Dict, contig_to_aa_counter: Dict, contig_to_aa_length: Dict, contamination_weight: float, - postProcessor:Optional[modelPostprocessing.modelProcessor] = None, + postProcessor: Optional[modelPostprocessing.modelProcessor] = None, threads: int = 1, - chunk_size: int = 2500): + chunk_size: int = 2500, +): """ Assess the quality of bins in chunks. @@ -206,7 +237,7 @@ def assess_bins_quality_by_chunk(bins: Iterable[Bin], :param contamination_weight: Weight for contamination assessment. :param postProcessor: post-processor from checkm2 :param threads: Number of threads for parallel processing (default is 1). - :param chunk_size: The size of each chunk. + :param chunk_size: The size of each chunk. """ for i, chunk_bins_iter in enumerate(chunks(bins, chunk_size)): @@ -214,15 +245,16 @@ def assess_bins_quality_by_chunk(bins: Iterable[Bin], logging.debug(f"chunk {i}: assessing quality of {len(chunk_bins)}") assess_bins_quality( bins=chunk_bins, - contig_to_kegg_counter= contig_to_kegg_counter, + contig_to_kegg_counter=contig_to_kegg_counter, contig_to_cds_count=contig_to_cds_count, contig_to_aa_counter=contig_to_aa_counter, contig_to_aa_length=contig_to_aa_length, contamination_weight=contamination_weight, postProcessor=postProcessor, - threads=threads + threads=threads, ) + def assess_bins_quality( bins: Iterable[Bin], contig_to_kegg_counter: Dict, @@ -231,11 +263,12 @@ def assess_bins_quality( contig_to_aa_length: Dict, contamination_weight: float, postProcessor: Optional[modelPostprocessing.modelProcessor] = None, - threads: int = 1,): + threads: int = 1, +): """ Assess the quality of bins. - This function assesses the quality of bins based on various criteria and assigns completeness and contamination scores. + This function assesses the quality of bins based on various criteria and assigns completeness and contamination scores. This code is taken from checkm2 and adjusted :param bins: List of bin objects. @@ -250,9 +283,13 @@ def assess_bins_quality( if postProcessor is None: postProcessor = modelPostprocessing.modelProcessor(threads) - metadata_df = get_bins_metadata_df(bins, contig_to_cds_count, contig_to_aa_counter, contig_to_aa_length) + metadata_df = get_bins_metadata_df( + bins, contig_to_cds_count, contig_to_aa_counter, contig_to_aa_length + ) - diamond_complete_results, ko_list_length = get_diamond_feature_per_bin_df(bins, contig_to_kegg_counter) + diamond_complete_results, ko_list_length = get_diamond_feature_per_bin_df( + bins, contig_to_kegg_counter + ) diamond_complete_results = diamond_complete_results.drop(columns=["Name"]) feature_vectors = pd.concat([metadata_df, diamond_complete_results], axis=1) @@ -264,20 +301,30 @@ def assess_bins_quality( vector_array = feature_vectors.iloc[:, 1:].values.astype(np.float) logging.info("Predicting completeness and contamination using the general model.") - general_results_comp, general_results_cont = modelProc.run_prediction_general(vector_array) + general_results_comp, general_results_cont = modelProc.run_prediction_general( + vector_array + ) logging.info("Predicting completeness using the specific model.") - specific_model_vector_len = ( - ko_list_length + len(metadata_df.columns) - ) - 1 + specific_model_vector_len = (ko_list_length + len(metadata_df.columns)) - 1 # also retrieve scaled data for CSM calculations - specific_results_comp, scaled_features = modelProc.run_prediction_specific(vector_array, specific_model_vector_len) + specific_results_comp, scaled_features = modelProc.run_prediction_specific( + vector_array, specific_model_vector_len + ) - logging.info("Using cosine similarity to reference data to select an appropriate predictor model.") + logging.info( + "Using cosine similarity to reference data to select an appropriate predictor model." + ) - final_comp, final_cont, models_chosen, csm_array = postProcessor.calculate_general_specific_ratio( - vector_array[:, 20], scaled_features, general_results_comp, general_results_cont, specific_results_comp + final_comp, final_cont, models_chosen, csm_array = ( + postProcessor.calculate_general_specific_ratio( + vector_array[:, 20], + scaled_features, + general_results_comp, + general_results_cont, + specific_results_comp, + ) ) final_results = feature_vectors[["Name"]].copy() diff --git a/binette/cds.py b/binette/cds.py index 8294641..cc125c5 100644 --- a/binette/cds.py +++ b/binette/cds.py @@ -1,4 +1,3 @@ - import concurrent.futures as cf import multiprocessing.pool import logging @@ -21,7 +20,10 @@ def get_contig_from_cds_name(cds_name: str) -> str: """ return "_".join(cds_name.split("_")[:-1]) -def predict(contigs_iterator: Iterator, outfaa: str, threads: int =1) -> Dict[str, List[str]]: + +def predict( + contigs_iterator: Iterator, outfaa: str, threads: int = 1 +) -> Dict[str, List[str]]: """ Predict open reading frames with Pyrodigal. @@ -33,17 +35,18 @@ def predict(contigs_iterator: Iterator, outfaa: str, threads: int =1) -> Dict[st """ try: # for version >=3 of pyrodigal - orf_finder = pyrodigal.GeneFinder(meta="meta") # type: ignore + orf_finder = pyrodigal.GeneFinder(meta="meta") # type: ignore except AttributeError: - orf_finder = pyrodigal.OrfFinder(meta="meta") # type: ignore + orf_finder = pyrodigal.OrfFinder(meta="meta") # type: ignore logging.info(f"Predicting cds sequences with Pyrodigal using {threads} threads.") - - with multiprocessing.pool.ThreadPool(processes=threads) as pool: - contig_and_genes = pool.starmap(predict_genes, ((orf_finder.find_genes, seq) for seq in contigs_iterator)) + + with multiprocessing.pool.ThreadPool(processes=threads) as pool: + contig_and_genes = pool.starmap( + predict_genes, ((orf_finder.find_genes, seq) for seq in contigs_iterator) + ) write_faa(outfaa, contig_and_genes) - contig_to_genes = { contig_id: [gene.translate() for gene in pyrodigal_genes] @@ -52,10 +55,10 @@ def predict(contigs_iterator: Iterator, outfaa: str, threads: int =1) -> Dict[st return contig_to_genes -def predict_genes(find_genes, seq) -> Tuple[str, pyrodigal.Genes]: +def predict_genes(find_genes, seq) -> Tuple[str, pyrodigal.Genes]: - return (seq.name, find_genes(seq.seq) ) + return (seq.name, find_genes(seq.seq)) def write_faa(outfaa: str, contig_to_genes: List[Tuple[str, pyrodigal.Genes]]) -> None: @@ -80,16 +83,15 @@ def is_nucleic_acid(sequence: str) -> bool: :return: True if the sequence is a DNA or RNA sequence, False otherwise. """ # Define nucleotidic bases (DNA and RNA) - nucleotidic_bases = set('ATCGNUatcgnu') - + nucleotidic_bases = set("ATCGNUatcgnu") + # Check if all characters in the sequence are valid nucleotidic bases (DNA or RNA) if all(base in nucleotidic_bases for base in sequence): return True - + # If any character is invalid, return False return False - def parse_faa_file(faa_file: str) -> Dict[str, List[str]]: """ @@ -106,7 +108,7 @@ def parse_faa_file(faa_file: str) -> Dict[str, List[str]]: for name, seq in pyfastx.Fastx(faa_file): contig = get_contig_from_cds_name(name) contig_to_genes[contig].append(seq) - + # Concatenate up to the first 20 sequences for validation if len(checked_sequences) < 20: checked_sequences.append(seq) @@ -122,7 +124,6 @@ def parse_faa_file(faa_file: str) -> Dict[str, List[str]]: ) return dict(contig_to_genes) - def get_aa_composition(genes: List[str]) -> Counter: @@ -138,24 +139,38 @@ def get_aa_composition(genes: List[str]) -> Counter: return aa_counter -def get_contig_cds_metadata_flat(contig_to_genes: Dict[str, List[str]]) -> Tuple[Dict[str, int], Dict[str, Counter], Dict[str, int]]: + +def get_contig_cds_metadata_flat( + contig_to_genes: Dict[str, List[str]] +) -> Tuple[Dict[str, int], Dict[str, Counter], Dict[str, int]]: """ Calculate metadata for contigs, including CDS count, amino acid composition, and total amino acid length. :param contig_to_genes: A dictionary mapping contig names to lists of protein sequences. :return: A tuple containing dictionaries for CDS count, amino acid composition, and total amino acid length. """ - contig_to_cds_count = {contig: len(genes) for contig, genes in contig_to_genes.items()} + contig_to_cds_count = { + contig: len(genes) for contig, genes in contig_to_genes.items() + } - contig_to_aa_counter = {contig: get_aa_composition(genes) for contig, genes in tqdm(contig_to_genes.items(), unit="contig")} + contig_to_aa_counter = { + contig: get_aa_composition(genes) + for contig, genes in tqdm(contig_to_genes.items(), unit="contig") + } logging.info("Calculating amino acid composition.") - contig_to_aa_length = {contig: sum(counter.values()) for contig, counter in tqdm(contig_to_aa_counter.items(), unit="contig")} + contig_to_aa_length = { + contig: sum(counter.values()) + for contig, counter in tqdm(contig_to_aa_counter.items(), unit="contig") + } logging.info("Calculating total amino acid length.") return contig_to_cds_count, contig_to_aa_counter, contig_to_aa_length -def get_contig_cds_metadata(contig_to_genes: Dict[int, Union[Any, List[Any]]], threads: int) -> Dict[str, Dict]: + +def get_contig_cds_metadata( + contig_to_genes: Dict[int, Union[Any, List[Any]]], threads: int +) -> Dict[str, Dict]: """ Calculate metadata for contigs in parallel, including CDS count, amino acid composition, and total amino acid length. @@ -163,7 +178,9 @@ def get_contig_cds_metadata(contig_to_genes: Dict[int, Union[Any, List[Any]]], :param threads: Number of CPU threads to use. :return: A tuple containing dictionaries for CDS count, amino acid composition, and total amino acid length. """ - contig_to_cds_count = {contig: len(genes) for contig, genes in contig_to_genes.items()} + contig_to_cds_count = { + contig: len(genes) for contig, genes in contig_to_genes.items() + } contig_to_future = {} logging.info(f"Collecting contig amino acid composition using {threads} threads.") @@ -171,10 +188,16 @@ def get_contig_cds_metadata(contig_to_genes: Dict[int, Union[Any, List[Any]]], for contig, genes in tqdm(contig_to_genes.items()): contig_to_future[contig] = tpe.submit(get_aa_composition, genes) - contig_to_aa_counter = {contig: future.result() for contig, future in tqdm(contig_to_future.items(), unit="contig")} + contig_to_aa_counter = { + contig: future.result() + for contig, future in tqdm(contig_to_future.items(), unit="contig") + } logging.info("Calculating amino acid composition in parallel.") - contig_to_aa_length = {contig: sum(counter.values()) for contig, counter in tqdm(contig_to_aa_counter.items(), unit="contig")} + contig_to_aa_length = { + contig: sum(counter.values()) + for contig, counter in tqdm(contig_to_aa_counter.items(), unit="contig") + } logging.info("Calculating total amino acid length in parallel.") contig_info = { @@ -183,4 +206,4 @@ def get_contig_cds_metadata(contig_to_genes: Dict[int, Union[Any, List[Any]]], "contig_to_aa_length": contig_to_aa_length, } - return contig_info \ No newline at end of file + return contig_info diff --git a/binette/contig_manager.py b/binette/contig_manager.py index 4b43733..fc783f5 100644 --- a/binette/contig_manager.py +++ b/binette/contig_manager.py @@ -27,7 +27,9 @@ def make_contig_index(contigs: Set[str]) -> Tuple[Dict[str, int], Dict[int, str] return contig_to_index, index_to_contig -def apply_contig_index(contig_to_index: Dict[str, int], contig_to_info: Dict[str, Any]) -> Dict[int, Union[Any,Iterable[Any]]]: +def apply_contig_index( + contig_to_index: Dict[str, int], contig_to_info: Dict[str, Any] +) -> Dict[int, Union[Any, Iterable[Any]]]: """ Apply the contig index mapping to the contig info dictionary. diff --git a/binette/diamond.py b/binette/diamond.py index 500f090..33fc834 100644 --- a/binette/diamond.py +++ b/binette/diamond.py @@ -19,16 +19,22 @@ def get_checkm2_db() -> str: logging.error("Make sure checkm2 is on your system path.") sys.exit(1) - checkm2_database_raw = subprocess.run(["checkm2", "database", "--current"], text=True, stderr=subprocess.PIPE) + checkm2_database_raw = subprocess.run( + ["checkm2", "database", "--current"], text=True, stderr=subprocess.PIPE + ) if checkm2_database_raw.returncode != 0: - logging.error(f"Something went wrong with checkm2:\n=======\n{checkm2_database_raw.stderr}========") + logging.error( + f"Something went wrong with checkm2:\n=======\n{checkm2_database_raw.stderr}========" + ) sys.exit(1) reg_result = re.search("INFO: (/.*.dmnd)", checkm2_database_raw.stderr) if reg_result is None: - logging.error(f"Something went wrong when retrieving checkm2 db path:\n{checkm2_database_raw.stderr}") + logging.error( + f"Something went wrong when retrieving checkm2 db path:\n{checkm2_database_raw.stderr}" + ) sys.exit(1) else: db_path = reg_result.group(1) @@ -36,7 +42,6 @@ def get_checkm2_db() -> str: return db_path - def check_tool_exists(tool_name: str): """ Check if a specified tool is on the system's PATH. @@ -46,12 +51,23 @@ def check_tool_exists(tool_name: str): :raises FileNotFoundError: If the tool is not found on the system's PATH. """ if shutil.which(tool_name) is None: - raise FileNotFoundError(f"The '{tool_name}' tool is not found on your system PATH.") + raise FileNotFoundError( + f"The '{tool_name}' tool is not found on your system PATH." + ) def run( - faa_file: str, output: str, db: str, log: str, threads: int = 1, query_cover: int = 80, subject_cover: int = 80, - percent_id: int = 30, evalue: float = 1e-05, low_mem: bool = False): + faa_file: str, + output: str, + db: str, + log: str, + threads: int = 1, + query_cover: int = 80, + subject_cover: int = 80, + percent_id: int = 30, + evalue: float = 1e-05, + low_mem: bool = False, +): """ Run Diamond with specified parameters. @@ -93,6 +109,7 @@ def run( logging.info("Finished Running DIAMOND") + def get_contig_to_kegg_id(diamond_result_file: str) -> dict: """ Get a dictionary mapping contig IDs to KEGG annotations from a Diamond result file. @@ -100,21 +117,31 @@ def get_contig_to_kegg_id(diamond_result_file: str) -> dict: :param diamond_result_file: Path to the Diamond result file. :return: A dictionary mapping contig IDs to KEGG annotations. """ - diamon_results_df = pd.read_csv(diamond_result_file, sep="\t", usecols=[0, 1], names=["ProteinID", "annotation"]) - diamon_results_df[["Ref100_hit", "Kegg_annotation"]] = diamon_results_df["annotation"].str.split( - "~", n=1, expand=True + diamon_results_df = pd.read_csv( + diamond_result_file, sep="\t", usecols=[0, 1], names=["ProteinID", "annotation"] ) + diamon_results_df[["Ref100_hit", "Kegg_annotation"]] = diamon_results_df[ + "annotation" + ].str.split("~", n=1, expand=True) KeggCalc = keggData.KeggCalculator() defaultKOs = KeggCalc.return_default_values_from_category("KO_Genes") - diamon_results_df = diamon_results_df.loc[diamon_results_df["Kegg_annotation"].isin(defaultKOs.keys())] - diamon_results_df["contig"] = diamon_results_df["ProteinID"].str.split("_", n=-1).str[:-1].str.join("_") + diamon_results_df = diamon_results_df.loc[ + diamon_results_df["Kegg_annotation"].isin(defaultKOs.keys()) + ] + diamon_results_df["contig"] = ( + diamon_results_df["ProteinID"].str.split("_", n=-1).str[:-1].str.join("_") + ) contig_to_kegg_counter = ( - diamon_results_df.groupby("contig").agg({"Kegg_annotation": Counter}).reset_index() + diamon_results_df.groupby("contig") + .agg({"Kegg_annotation": Counter}) + .reset_index() ) - contig_to_kegg_counter = dict(zip(contig_to_kegg_counter["contig"], contig_to_kegg_counter["Kegg_annotation"])) + contig_to_kegg_counter = dict( + zip(contig_to_kegg_counter["contig"], contig_to_kegg_counter["Kegg_annotation"]) + ) return contig_to_kegg_counter diff --git a/binette/io_manager.py b/binette/io_manager.py index 7dc1ca4..2c1aafe 100644 --- a/binette/io_manager.py +++ b/binette/io_manager.py @@ -7,7 +7,10 @@ from pathlib import Path -def get_paths_common_prefix_suffix(paths: List[Path]) -> Tuple[List[str], List[str], List[str]]: + +def get_paths_common_prefix_suffix( + paths: List[Path], +) -> Tuple[List[str], List[str], List[str]]: """ Determine the common prefix parts, suffix parts, and common extensions of the last part of a list of pathlib.Path objects. @@ -19,25 +22,33 @@ def get_paths_common_prefix_suffix(paths: List[Path]) -> Tuple[List[str], List[s """ # Extract parts for all paths parts = [list(path.parts) for path in paths] - + # Find the common prefix if not parts: return [], [], [] - + # Initialize common prefix and suffix lists common_prefix = list(parts[0]) common_suffix = list(parts[0]) # Determine common prefix for part_tuple in parts[1:]: common_prefix_length = min(len(common_prefix), len(part_tuple)) - common_prefix = [common_prefix[i] for i in range(common_prefix_length) if common_prefix[:i+1] == part_tuple[:i+1]] + common_prefix = [ + common_prefix[i] + for i in range(common_prefix_length) + if common_prefix[: i + 1] == part_tuple[: i + 1] + ] if not common_prefix: break # Determine common suffix for part_tuple in parts[1:]: common_suffix_length = min(len(common_suffix), len(part_tuple)) - common_suffix = [common_suffix[-i] for i in range(1, common_suffix_length + 1) if common_suffix[-i:] == part_tuple[-i:]] + common_suffix = [ + common_suffix[-i] + for i in range(1, common_suffix_length + 1) + if common_suffix[-i:] == part_tuple[-i:] + ] if not common_suffix: break if len(parts) > 1: @@ -50,12 +61,17 @@ def get_paths_common_prefix_suffix(paths: List[Path]) -> Tuple[List[str], List[s common_extensions = list(paths[0].suffixes) for path in paths[1:]: common_extension_length = min(len(common_extensions), len(path.suffixes)) - common_extensions = [common_extensions[i] for i in range(common_extension_length) if common_extensions[i] == path.suffixes[i]] + common_extensions = [ + common_extensions[i] + for i in range(common_extension_length) + if common_extensions[i] == path.suffixes[i] + ] if not common_extensions: break - + return common_prefix, common_suffix, common_extensions - + + def infer_bin_set_names_from_input_paths(input_bins: List[Path]) -> Dict[str, Path]: """ Infer bin set names from a list of bin input directories or files. @@ -65,18 +81,23 @@ def infer_bin_set_names_from_input_paths(input_bins: List[Path]) -> Dict[str, Pa """ bin_name_to_bin_dir = {} - common_prefix, common_suffix, common_extensions = get_paths_common_prefix_suffix(input_bins) + common_prefix, common_suffix, common_extensions = get_paths_common_prefix_suffix( + input_bins + ) for path in input_bins: - specific_parts = path.parts[len(common_prefix):len(path.parts)-len(common_suffix)] + specific_parts = path.parts[ + len(common_prefix) : len(path.parts) - len(common_suffix) + ] if not common_suffix and common_extensions: - last_specific_part = specific_parts[-1].split('.')[:-len(common_extensions)] + last_specific_part = specific_parts[-1].split(".")[ + : -len(common_extensions) + ] specific_parts = list(specific_parts[:-1]) + last_specific_part - - bin_set_name = '/'.join(specific_parts) + bin_set_name = "/".join(specific_parts) if bin_set_name == "": bin_set_name = path.as_posix() @@ -100,15 +121,25 @@ def write_bin_info(bins: Iterable[Bin], output: Path, add_contigs: bool = False) :param add_contigs: Flag indicating whether to include contig information. """ - header = ["bin_id", "origin", "name", "completeness", "contamination", "score", "size", "N50", "contig_count"] + header = [ + "bin_id", + "origin", + "name", + "completeness", + "contamination", + "score", + "size", + "N50", + "contig_count", + ] if add_contigs: - header.append('contigs') + header.append("contigs") bin_infos = [] for bin_obj in sorted(bins, key=lambda x: (x.score, x.N50, -x.id), reverse=True): bin_info = [ bin_obj.id, - ';'.join(bin_obj.origin), + ";".join(bin_obj.origin), bin_obj.name, bin_obj.completeness, bin_obj.contamination, @@ -118,7 +149,9 @@ def write_bin_info(bins: Iterable[Bin], output: Path, add_contigs: bool = False) len(bin_obj.contigs), ] if add_contigs: - bin_info.append(";".join(str(c) for c in bin_obj.contigs) if add_contigs else "") + bin_info.append( + ";".join(str(c) for c in bin_obj.contigs) if add_contigs else "" + ) bin_infos.append(bin_info) @@ -147,10 +180,12 @@ def write_bins_fasta(selected_bins: List[Bin], contigs_fasta: Path, outdir: Path outfl.write("\n".join(sequences) + "\n") -def check_contig_consistency(contigs_from_assembly: Iterable[str], - contigs_from_elsewhere: Iterable[str], - assembly_file: str, - elsewhere_file: str ): +def check_contig_consistency( + contigs_from_assembly: Iterable[str], + contigs_from_elsewhere: Iterable[str], + assembly_file: str, + elsewhere_file: str, +): """ Check the consistency of contig names between different sources. @@ -161,14 +196,16 @@ def check_contig_consistency(contigs_from_assembly: Iterable[str], :raises AssertionError: If inconsistencies in contig names are found. """ logging.debug("check_contig_consistency.") - are_contigs_consistent = len(set(contigs_from_elsewhere) | set(contigs_from_assembly)) <= len( - set(contigs_from_assembly) - ) + are_contigs_consistent = len( + set(contigs_from_elsewhere) | set(contigs_from_assembly) + ) <= len(set(contigs_from_assembly)) issue_countigs = len(set(contigs_from_elsewhere) - set(contigs_from_assembly)) - - message = (f"{issue_countigs} contigs found in file '{elsewhere_file}' " - f"were not found in assembly_file '{assembly_file}'") + + message = ( + f"{issue_countigs} contigs found in file '{elsewhere_file}' " + f"were not found in assembly_file '{assembly_file}'" + ) assert are_contigs_consistent, message @@ -185,7 +222,9 @@ def check_resume_file(faa_file: Path, diamond_result_file: Path) -> None: return if not faa_file.exists(): - error_msg = f"Protein file '{faa_file}' does not exist. Resuming is not possible." + error_msg = ( + f"Protein file '{faa_file}' does not exist. Resuming is not possible." + ) logging.error(error_msg) raise FileNotFoundError(error_msg) @@ -195,7 +234,9 @@ def check_resume_file(faa_file: Path, diamond_result_file: Path) -> None: raise FileNotFoundError(error_msg) -def write_original_bin_metrics(bin_set_name_to_bins: Dict[str, Set[Bin]], original_bin_report_dir: Path): +def write_original_bin_metrics( + bin_set_name_to_bins: Dict[str, Set[Bin]], original_bin_report_dir: Path +): """ Write metrics of original input bins to a specified directory. @@ -203,7 +244,7 @@ def write_original_bin_metrics(bin_set_name_to_bins: Dict[str, Set[Bin]], origin the metrics for each bin set to a TSV file in the specified directory. Each bin set will have its own TSV file named according to its set name. - :param bin_set_name_to_bins: A dictionary where the keys are bin set names (str) and + :param bin_set_name_to_bins: A dictionary where the keys are bin set names (str) and the values are sets of Bin objects representing bins. :param original_bin_report_dir: The directory path (Path) where the bin metrics will be saved. """ @@ -211,9 +252,14 @@ def write_original_bin_metrics(bin_set_name_to_bins: Dict[str, Set[Bin]], origin original_bin_report_dir.mkdir(parents=True, exist_ok=True) for i, (set_name, bins) in enumerate(sorted(bin_set_name_to_bins.items())): - bins_metric_file = original_bin_report_dir / f"input_bins_{i + 1}.{set_name.replace('/', '_')}.tsv" - - logging.debug(f"Writing metrics for bin set '{set_name}' to file: {bins_metric_file}") + bins_metric_file = ( + original_bin_report_dir + / f"input_bins_{i + 1}.{set_name.replace('/', '_')}.tsv" + ) + + logging.debug( + f"Writing metrics for bin set '{set_name}' to file: {bins_metric_file}" + ) write_bin_info(bins, bins_metric_file) logging.debug("Completed writing all original input bin metrics.") diff --git a/binette/main.py b/binette/main.py index 8bd6bac..bd528d5 100755 --- a/binette/main.py +++ b/binette/main.py @@ -15,10 +15,18 @@ import os import binette -from binette import contig_manager, cds, diamond, bin_quality, bin_manager, io_manager as io +from binette import ( + contig_manager, + cds, + diamond, + bin_quality, + bin_manager, + io_manager as io, +) from typing import List, Dict, Optional, Set, Tuple, Union, Sequence, Any from pathlib import Path + def init_logging(verbose, debug): """Initialise logging.""" if debug: @@ -46,11 +54,11 @@ class UniqueStore(Action): """ def __call__( - self, - parser: ArgumentParser, - namespace: Namespace, - values: Union[str, Sequence[Any], None], - option_string: Optional[str] = None + self, + parser: ArgumentParser, + namespace: Namespace, + values: Union[str, Sequence[Any], None], + option_string: Optional[str] = None, ) -> None: """ Ensures the argument is only used once. Raises an error if the argument appears multiple times. @@ -62,8 +70,10 @@ def __call__( """ # Check if the argument has already been set if getattr(namespace, self.dest, self.default) is not self.default: - parser.error(f"Error: The argument {option_string} can only be specified once.") - + parser.error( + f"Error: The argument {option_string} can only be specified once." + ) + # Set the argument value setattr(namespace, self.dest, values) @@ -81,9 +91,10 @@ def is_valid_file(parser: ArgumentParser, arg: str) -> Path: # Check if the file exists at the provided path if not path_arg.exists(): parser.error(f"Error: The specified file '{arg}' does not exist.") - + return path_arg + def parse_arguments(args): """Parse script arguments.""" @@ -93,7 +104,7 @@ def parse_arguments(args): ) # Input arguments category - input_group = parser.add_argument_group('Input Arguments') + input_group = parser.add_argument_group("Input Arguments") input_arg = input_group.add_mutually_exclusive_group(required=True) input_arg.add_argument( @@ -115,20 +126,24 @@ def parse_arguments(args): with a tabulation: contig, bin", ) - input_group.add_argument("-c", "--contigs", required=True, - type=lambda x: is_valid_file(parser, x), - help="Contigs in fasta format.") + input_group.add_argument( + "-c", + "--contigs", + required=True, + type=lambda x: is_valid_file(parser, x), + help="Contigs in fasta format.", + ) input_group.add_argument( - "-p", "--proteins", + "-p", + "--proteins", type=lambda x: is_valid_file(parser, x), help="FASTA file of predicted proteins in Prodigal format (>contigID_geneID). " - "Skips the gene prediction step if provided." + "Skips the gene prediction step if provided.", ) - # Other parameters category - other_group = parser.add_argument_group('Other Arguments') + other_group = parser.add_argument_group("Other Arguments") other_group.add_argument( "-m", @@ -138,9 +153,13 @@ def parse_arguments(args): help="Minimum completeness required for final bin selections.", ) - other_group.add_argument("-t", "--threads", default=1, type=int, help="Number of threads to use.") + other_group.add_argument( + "-t", "--threads", default=1, type=int, help="Number of threads to use." + ) - other_group.add_argument("-o", "--outdir", default=Path("results"), type=Path, help="Output directory.") + other_group.add_argument( + "-o", "--outdir", default=Path("results"), type=Path, help="Output directory." + ) other_group.add_argument( "-w", @@ -148,9 +167,9 @@ def parse_arguments(args): default=2, type=float, help="Bin are scored as follow: completeness - weight * contamination. " - "A low contamination_weight favor complete bins over low contaminated bins.", + "A low contamination_weight favor complete bins over low contaminated bins.", ) - + other_group.add_argument( "-e", "--fasta_extensions", @@ -164,31 +183,40 @@ def parse_arguments(args): "--checkm2_db", type=Path, help="Provide a path for the CheckM2 diamond database. " - "By default the database set via is used." + "By default the database set via is used.", ) - other_group.add_argument("--low_mem", help="Use low mem mode when running diamond", action="store_true") + other_group.add_argument( + "--low_mem", help="Use low mem mode when running diamond", action="store_true" + ) - other_group.add_argument("-v", "--verbose", help="increase output verbosity", action="store_true") + other_group.add_argument( + "-v", "--verbose", help="increase output verbosity", action="store_true" + ) other_group.add_argument("--debug", help="Activate debug mode", action="store_true") - other_group.add_argument("--resume", - action="store_true", - help="Activate resume mode. Binette will examine the 'temporary_files' directory " - "within the output directory and reuse any existing files if possible." - ) - + other_group.add_argument( + "--resume", + action="store_true", + help="Activate resume mode. Binette will examine the 'temporary_files' directory " + "within the output directory and reuse any existing files if possible.", + ) other_group.add_argument("--version", action="version", version=binette.__version__) args = parser.parse_args(args) return args -def parse_input_files(bin_dirs: List[Path], - contig2bin_tables: List[Path], - contigs_fasta: Path, - fasta_extensions:Set[str] = {".fasta", ".fna", ".fa"}) -> Tuple[Dict[str, Set[bin_manager.Bin]], Set[bin_manager.Bin], Set[str], Dict[str, int]]: + +def parse_input_files( + bin_dirs: List[Path], + contig2bin_tables: List[Path], + contigs_fasta: Path, + fasta_extensions: Set[str] = {".fasta", ".fna", ".fa"}, +) -> Tuple[ + Dict[str, Set[bin_manager.Bin]], Set[bin_manager.Bin], Set[str], Dict[str, int] +]: """ Parses input files to retrieve information related to bins and contigs. @@ -207,11 +235,17 @@ def parse_input_files(bin_dirs: List[Path], if bin_dirs: logging.info("Parsing bin directories.") bin_name_to_bin_dir = io.infer_bin_set_names_from_input_paths(bin_dirs) - bin_set_name_to_bins = bin_manager.parse_bin_directories(bin_name_to_bin_dir, fasta_extensions) + bin_set_name_to_bins = bin_manager.parse_bin_directories( + bin_name_to_bin_dir, fasta_extensions + ) else: logging.info("Parsing bin2contig files.") - bin_name_to_bin_table = io.infer_bin_set_names_from_input_paths(contig2bin_tables) - bin_set_name_to_bins = bin_manager.parse_contig2bin_tables(bin_name_to_bin_table) + bin_name_to_bin_table = io.infer_bin_set_names_from_input_paths( + contig2bin_tables + ) + bin_set_name_to_bins = bin_manager.parse_contig2bin_tables( + bin_name_to_bin_table + ) logging.info(f"Processing {len(bin_set_name_to_bins)} bin sets.") for bin_set_id, bins in bin_set_name_to_bins.items(): @@ -223,25 +257,38 @@ def parse_input_files(bin_dirs: List[Path], logging.info(f"Parsing contig fasta file: {contigs_fasta}") contigs_object = contig_manager.parse_fasta_file(contigs_fasta.as_posix()) - unexpected_contigs = {contig for contig in contigs_in_bins if contig not in contigs_object} + unexpected_contigs = { + contig for contig in contigs_in_bins if contig not in contigs_object + } if len(unexpected_contigs): - raise ValueError(f"{len(unexpected_contigs)} contigs from the input bins were not found in the contigs file '{contigs_fasta}'. " - f"The missing contigs are: {', '.join(unexpected_contigs)}. Please ensure all contigs from input bins are present in contig file.") + raise ValueError( + f"{len(unexpected_contigs)} contigs from the input bins were not found in the contigs file '{contigs_fasta}'. " + f"The missing contigs are: {', '.join(unexpected_contigs)}. Please ensure all contigs from input bins are present in contig file." + ) - contig_to_length = {seq.name: len(seq) for seq in contigs_object if seq.name in contigs_in_bins} + contig_to_length = { + seq.name: len(seq) for seq in contigs_object if seq.name in contigs_in_bins + } return bin_set_name_to_bins, original_bins, contigs_in_bins, contig_to_length -def manage_protein_alignement(faa_file: Path, contigs_fasta: Path, contig_to_length: Dict[str, int], - contigs_in_bins: Set[str], diamond_result_file: Path, - checkm2_db: Optional[Path], threads: int, use_existing_protein_file: bool, - resume_diamond:bool, - low_mem: bool) -> Tuple[Dict[str, int], Dict[str, List[str]]]: +def manage_protein_alignement( + faa_file: Path, + contigs_fasta: Path, + contig_to_length: Dict[str, int], + contigs_in_bins: Set[str], + diamond_result_file: Path, + checkm2_db: Optional[Path], + threads: int, + use_existing_protein_file: bool, + resume_diamond: bool, + low_mem: bool, +) -> Tuple[Dict[str, int], Dict[str, List[str]]]: """ Predicts or reuses proteins prediction and runs diamond on them. - + :param faa_file: The path to the .faa file. :param contigs_fasta: The path to the contigs FASTA file. :param contig_to_length: Dictionary mapping contig names to their lengths. @@ -260,10 +307,19 @@ def manage_protein_alignement(faa_file: Path, contigs_fasta: Path, contig_to_len if use_existing_protein_file: logging.info(f"Parsing faa file: {faa_file}.") contig_to_genes = cds.parse_faa_file(faa_file.as_posix()) - io.check_contig_consistency(contig_to_length, contig_to_genes, contigs_fasta.as_posix(), faa_file.as_posix()) + io.check_contig_consistency( + contig_to_length, + contig_to_genes, + contigs_fasta.as_posix(), + faa_file.as_posix(), + ) else: - contigs_iterator = (s for s in contig_manager.parse_fasta_file(contigs_fasta.as_posix()) if s.name in contigs_in_bins) + contigs_iterator = ( + s + for s in contig_manager.parse_fasta_file(contigs_fasta.as_posix()) + if s.name in contigs_in_bins + ) contig_to_genes = cds.predict(contigs_iterator, faa_file.as_posix(), threads) if not resume_diamond: @@ -275,7 +331,7 @@ def manage_protein_alignement(faa_file: Path, contigs_fasta: Path, contig_to_len else: raise FileNotFoundError(checkm2_db) - diamond_log = diamond_result_file.parents[0] / f"{diamond_result_file.stem}.log" + diamond_log = diamond_result_file.parents[0] / f"{diamond_result_file.stem}.log" diamond.run( faa_file.as_posix(), @@ -287,18 +343,30 @@ def manage_protein_alignement(faa_file: Path, contigs_fasta: Path, contig_to_len ) logging.info("Parsing diamond results.") - contig_to_kegg_counter = diamond.get_contig_to_kegg_id(diamond_result_file.as_posix()) + contig_to_kegg_counter = diamond.get_contig_to_kegg_id( + diamond_result_file.as_posix() + ) # Check contigs from diamond vs input assembly consistency - io.check_contig_consistency(contig_to_length, contig_to_kegg_counter, contigs_fasta.as_posix(), diamond_result_file.as_posix()) + io.check_contig_consistency( + contig_to_length, + contig_to_kegg_counter, + contigs_fasta.as_posix(), + diamond_result_file.as_posix(), + ) return contig_to_kegg_counter, contig_to_genes -def select_bins_and_write_them(all_bins: Set[bin_manager.Bin], - contigs_fasta: Path, - final_bin_report: Path, min_completeness: float, - index_to_contig: dict, outdir: Path, debug: bool) -> List[bin_manager.Bin]: +def select_bins_and_write_them( + all_bins: Set[bin_manager.Bin], + contigs_fasta: Path, + final_bin_report: Path, + min_completeness: float, + index_to_contig: dict, + outdir: Path, + debug: bool, +) -> List[bin_manager.Bin]: """ Selects and writes bins based on specific criteria. @@ -318,29 +386,31 @@ def select_bins_and_write_them(all_bins: Set[bin_manager.Bin], if debug: all_bins_for_debug = set(all_bins) all_bin_compo_file = outdir / "all_bins_quality_reports.tsv" - + logging.info(f"Writing all bins in {all_bin_compo_file}") - + io.write_bin_info(all_bins_for_debug, all_bin_compo_file, add_contigs=True) - - with open(os.path.join(outdir, "index_to_contig.tsv"), 'w') as flout: - flout.write('\n'.join((f'{i}\t{c}' for i, c in index_to_contig.items()))) + + with open(os.path.join(outdir, "index_to_contig.tsv"), "w") as flout: + flout.write("\n".join((f"{i}\t{c}" for i, c in index_to_contig.items()))) logging.info("Selecting best bins") selected_bins = bin_manager.select_best_bins(all_bins) logging.info(f"Bin Selection: {len(selected_bins)} selected bins") - logging.info(f"Filtering bins: only bins with completeness >= {min_completeness} are kept") + logging.info( + f"Filtering bins: only bins with completeness >= {min_completeness} are kept" + ) selected_bins = [b for b in selected_bins if b.is_complete_enough(min_completeness)] logging.info(f"Filtering bins: {len(selected_bins)} selected bins") logging.info(f"Writing selected bins in {final_bin_report}") - + for b in selected_bins: b.contigs = {index_to_contig[c_index] for c_index in b.contigs} - + io.write_bin_info(selected_bins, final_bin_report) io.write_bins_fasta(selected_bins, contigs_fasta, outdir_final_bin_set) @@ -348,8 +418,11 @@ def select_bins_and_write_them(all_bins: Set[bin_manager.Bin], return selected_bins - -def log_selected_bin_info(selected_bins: List[bin_manager.Bin], hq_min_completeness: float, hq_max_conta: float): +def log_selected_bin_info( + selected_bins: List[bin_manager.Bin], + hq_min_completeness: float, + hq_max_conta: float, +): """ Log information about selected bins based on quality thresholds. @@ -364,21 +437,39 @@ def log_selected_bin_info(selected_bins: List[bin_manager.Bin], hq_min_completen # Log completeness and contamination in debug log logging.debug("High quality bins:") for sb in selected_bins: - if sb.is_high_quality(min_completeness=hq_min_completeness, max_contamination=hq_max_conta): - logging.debug(f"> {sb} completeness={sb.completeness}, contamination={sb.contamination}") + if sb.is_high_quality( + min_completeness=hq_min_completeness, max_contamination=hq_max_conta + ): + logging.debug( + f"> {sb} completeness={sb.completeness}, contamination={sb.contamination}" + ) # Count high-quality bins and single-contig high-quality bins - hq_bins = len([sb for sb in selected_bins if sb.is_high_quality(min_completeness=hq_min_completeness, max_contamination=hq_max_conta)]) + hq_bins = len( + [ + sb + for sb in selected_bins + if sb.is_high_quality( + min_completeness=hq_min_completeness, max_contamination=hq_max_conta + ) + ] + ) # Log information about high-quality bins - thresholds = f"(completeness >= {hq_min_completeness} and contamination <= {hq_max_conta})" - logging.info(f"{hq_bins}/{len(selected_bins)} selected bins have a high quality {thresholds}.") + thresholds = ( + f"(completeness >= {hq_min_completeness} and contamination <= {hq_max_conta})" + ) + logging.info( + f"{hq_bins}/{len(selected_bins)} selected bins have a high quality {thresholds}." + ) def main(): "Orchestrate the execution of the program" - args = parse_arguments(sys.argv[1:]) # sys.argv is passed in order to be able to test the function parse_arguments + args = parse_arguments( + sys.argv[1:] + ) # sys.argv is passed in order to be able to test the function parse_arguments init_logging(args.verbose, args.debug) @@ -387,9 +478,9 @@ def main(): hq_min_completeness = 90 # Temporary files # - out_tmp_dir:Path = args.outdir / "temporary_files" + out_tmp_dir: Path = args.outdir / "temporary_files" os.makedirs(out_tmp_dir, exist_ok=True) - + use_existing_protein_file = False if args.proteins: @@ -399,64 +490,91 @@ def main(): else: faa_file = out_tmp_dir / "assembly_proteins.faa" - diamond_result_file = out_tmp_dir / "diamond_result.tsv" # Output files # - final_bin_report:Path = args.outdir / "final_bins_quality_reports.tsv" - original_bin_report_dir:Path = args.outdir / "input_bins_quality_reports" + final_bin_report: Path = args.outdir / "final_bins_quality_reports.tsv" + original_bin_report_dir: Path = args.outdir / "input_bins_quality_reports" if args.resume: io.check_resume_file(faa_file, diamond_result_file) use_existing_protein_file = True - bin_set_name_to_bins, original_bins, contigs_in_bins, contig_to_length = parse_input_files(args.bin_dirs, args.contig2bin_tables, args.contigs, fasta_extensions=set(args.fasta_extensions)) + bin_set_name_to_bins, original_bins, contigs_in_bins, contig_to_length = ( + parse_input_files( + args.bin_dirs, + args.contig2bin_tables, + args.contigs, + fasta_extensions=set(args.fasta_extensions), + ) + ) + + contig_to_kegg_counter, contig_to_genes = manage_protein_alignement( + faa_file=faa_file, + contigs_fasta=args.contigs, + contig_to_length=contig_to_length, + contigs_in_bins=contigs_in_bins, + diamond_result_file=diamond_result_file, + checkm2_db=args.checkm2_db, + threads=args.threads, + use_existing_protein_file=use_existing_protein_file, + resume_diamond=args.resume, + low_mem=args.low_mem, + ) - contig_to_kegg_counter, contig_to_genes = manage_protein_alignement(faa_file=faa_file, contigs_fasta=args.contigs, contig_to_length=contig_to_length, - contigs_in_bins=contigs_in_bins, - diamond_result_file=diamond_result_file, checkm2_db=args.checkm2_db, - threads=args.threads, use_existing_protein_file=use_existing_protein_file, - resume_diamond=args.resume, low_mem=args.low_mem) - # Use contig index instead of contig name to save memory contig_to_index, index_to_contig = contig_manager.make_contig_index(contigs_in_bins) - contig_to_kegg_counter = contig_manager.apply_contig_index(contig_to_index, contig_to_kegg_counter) - contig_to_genes = contig_manager.apply_contig_index(contig_to_index, contig_to_genes) - contig_to_length = contig_manager.apply_contig_index(contig_to_index, contig_to_length) + contig_to_kegg_counter = contig_manager.apply_contig_index( + contig_to_index, contig_to_kegg_counter + ) + contig_to_genes = contig_manager.apply_contig_index( + contig_to_index, contig_to_genes + ) + contig_to_length = contig_manager.apply_contig_index( + contig_to_index, contig_to_length + ) bin_manager.rename_bin_contigs(original_bins, contig_to_index) - # Extract cds metadata ## logging.info("Compute cds metadata.") contig_metadat = cds.get_contig_cds_metadata(contig_to_genes, args.threads) contig_metadat["contig_to_kegg_counter"] = contig_to_kegg_counter contig_metadat["contig_to_length"] = contig_to_length - logging.info("Add size and assess quality of input bins") - bin_quality.add_bin_metrics(original_bins, contig_metadat, args.contamination_weight, args.threads) - - + bin_quality.add_bin_metrics( + original_bins, contig_metadat, args.contamination_weight, args.threads + ) - logging.info(f"Writting original input bin metrics to directory: {original_bin_report_dir}") + logging.info( + f"Writting original input bin metrics to directory: {original_bin_report_dir}" + ) io.write_original_bin_metrics(bin_set_name_to_bins, original_bin_report_dir) - logging.info("Create intermediate bins:") new_bins = bin_manager.create_intermediate_bins(bin_set_name_to_bins) logging.info("Assess quality for supplementary intermediate bins.") - new_bins = bin_quality.add_bin_metrics(new_bins, contig_metadat, args.contamination_weight, args.threads) - + new_bins = bin_quality.add_bin_metrics( + new_bins, contig_metadat, args.contamination_weight, args.threads + ) logging.info("Dereplicating input bins and new bins") all_bins = original_bins | new_bins - selected_bins = select_bins_and_write_them(all_bins, args.contigs, final_bin_report, args.min_completeness, index_to_contig, args.outdir, args.debug) + selected_bins = select_bins_and_write_them( + all_bins, + args.contigs, + final_bin_report, + args.min_completeness, + index_to_contig, + args.outdir, + args.debug, + ) log_selected_bin_info(selected_bins, hq_min_completeness, hq_max_conta) - return 0 \ No newline at end of file + return 0 diff --git a/docs/conf.py b/docs/conf.py index 0cce3b6..662bd9f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -8,9 +8,9 @@ from binette import __version__ -project = 'Binette' -copyright = '2024, Jean Mainguy' -author = 'Jean Mainguy' +project = "Binette" +copyright = "2024, Jean Mainguy" +author = "Jean Mainguy" release = __version__ @@ -18,20 +18,20 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration extensions = [ - # + # # "sphinxcontrib.jquery", "sphinx.ext.duration", "sphinx.ext.autosectionlabel", "sphinx.ext.autodoc", - 'sphinx_search.extension', - 'sphinx_togglebutton', + "sphinx_search.extension", + "sphinx_togglebutton", # "myst_nb", "myst_parser", - 'nbsphinx', - 'nbsphinx_link', + "nbsphinx", + "nbsphinx_link", # 'sphinx.ext.napoleon', # 'sphinx.ext.viewcode', - 'sphinxcontrib.mermaid' + "sphinxcontrib.mermaid", ] myst_enable_extensions = [ "amsmath", @@ -49,35 +49,31 @@ ] source_suffix = { - '.md': 'markdown', + ".md": "markdown", } -templates_path = ['_templates'] +templates_path = ["_templates"] nb_execution_mode = "off" -nbsphinx_execute = 'never' +nbsphinx_execute = "never" # Prefix document path to section labels, to use: # `path/to/file:heading` instead of just `heading` autosectionlabel_prefix_document = True -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'build', "jupyter_execute"] - +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "build", "jupyter_execute"] # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -# html_theme = 'sphinx_rtd_theme' #'alabaster' # -html_theme = 'sphinx_rtd_theme' #'sphinx_book_theme' +# html_theme = 'sphinx_rtd_theme' #'alabaster' # +html_theme = "sphinx_rtd_theme" #'sphinx_book_theme' # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] - - - +html_static_path = ["_static"] # Include the Plotly JavaScript in the HTML output @@ -85,18 +81,12 @@ # Ensures that the `require.js` is loaded for Plotly to function correctly nbsphinx_requirejs_options = { - 'paths': { - 'plotly': 'https://cdn.plot.ly/plotly-latest.min' - }, - 'shim': { - 'plotly': { - 'exports': 'Plotly' - } - } + "paths": {"plotly": "https://cdn.plot.ly/plotly-latest.min"}, + "shim": {"plotly": {"exports": "Plotly"}}, } # Specify the default language for syntax highlighting in Sphinx -highlight_language = 'python' +highlight_language = "python" # -- Options for HTML output ------------------------------------------------- @@ -104,7 +94,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Add plotly renderer options nbsphinx_prolog = r""" @@ -112,6 +102,3 @@ """ - - - diff --git a/docs/contributing.md b/docs/contributing.md index e9b69ee..1ffcd71 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -28,20 +28,37 @@ For minor changes like fixing typos or making small edits, create a new Pull Req 2. **Get an Environment:** Create an environment with all Binette prerequisites installed by following the installation instructions [here](./installation.md#from-the-source-code-within-a-conda-environnement). -3. **Install in Editable Mode:** - To enable code editing and testing of new functionality, you can install Binette in editable mode using the following command: +3. **Branch from 'dev':** + Start your changes from the `dev` branch, where updates for the upcoming release are integrated. ```bash - pip install -e . + git checkout dev ``` - This allows you to modify the code and experiment with new features directly. - - -```{note} -Currently, we are not utilizing any auto formatters (like autopep8 or black). Kindly refrain from using them, as it could introduce extensive changes across the project, making code review challenging for us. -``` - +4. **Install in Editable Mode:** + To enable code editing and testing of new functionality, you can install Binette in editable mode: + ```bash + pip install -e .[dev] + ``` + - The `[dev]` part installs additional packages required for development. While not mandatory, it is recommended for an optimal development environment. Refer to the `pyproject.toml` file for details on the additional dependencies that will be installed. + + Installing in editable mode allows you to modify the codebase and experiment with new features directly. + +5. **Apply Code Formatting with Black:** + To maintain consistent code styling, we use [Black](https://github.com/psf/black) as our code formatter. + - Code changes are automatically checked for formatting as part of our CI pipeline via a GitHub Action. + - **Ensure your code is formatted with Black before committing.** + + To format your code: + 1. Install Black (e.g., `pip install black`). + 2. Run Black from the root of the Binette repository: + ```bash + black . + ``` + + ```{tip} + Configure your IDE to integrate Black for automatic code formatting as you work. + ``` ### Making Your Changes diff --git a/pyproject.toml b/pyproject.toml index 2769cd9..f0820f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,8 @@ doc = [ dev = [ "pytest>=7.0.0", - "pytest-cov" + "pytest-cov", + "black" ] # [project.urls] diff --git a/tests/bin_manager_test.py b/tests/bin_manager_test.py index f939b53..cac472d 100644 --- a/tests/bin_manager_test.py +++ b/tests/bin_manager_test.py @@ -11,6 +11,7 @@ import logging from pathlib import Path + def test_get_all_possible_combinations(): input_list = ["2", "3", "4"] expected_list = [("2", "3"), ("2", "4"), ("3", "4"), ("2", "3", "4")] @@ -25,12 +26,14 @@ def example_bin_set1(): bin3 = bin_manager.Bin(contigs={"5"}, origin="test1", name="bin2") return {bin1, bin2, bin3} + @pytest.fixture def example_bin_set2(): bin1 = bin_manager.Bin(contigs={"1", "2", "3"}, origin="test2", name="binA") bin2 = bin_manager.Bin(contigs={"4", "5"}, origin="test2", name="binB") return {bin1, bin2} + def test_bin_eq_true(): bin1 = bin_manager.Bin(contigs={"1", "2", "e"}, origin="test1", name="bin1") bin2 = bin_manager.Bin(contigs={"1", "e", "2"}, origin="test2", name="binA") @@ -56,48 +59,49 @@ def test_in_for_bin_list(): assert bin2 in bins assert bin3 not in bins + def test_add_length_positive_integer(): bin_obj = bin_manager.Bin(contigs={"1", "2", "e"}, origin="test1", name="bin1") length = 100 bin_obj.add_length(length) assert bin_obj.length == length + def test_add_length_negative_integer(): - bin_obj = bin_manager.Bin(contigs={"1", "2", "e"}, origin="test1", name="bin1") + bin_obj = bin_manager.Bin(contigs={"1", "2", "e"}, origin="test1", name="bin1") with pytest.raises(ValueError): length = -50 bin_obj.add_length(length) + def test_add_n50_positive_integer(): bin_obj = bin_manager.Bin(contigs={"1", "2", "e"}, origin="test1", name="bin1") n50 = 100 bin_obj.add_N50(n50) assert bin_obj.N50 == n50 + def test_add_n50_negative_integer(): - bin_obj = bin_manager.Bin(contigs={"1", "2", "e"}, origin="test1", name="bin1") + bin_obj = bin_manager.Bin(contigs={"1", "2", "e"}, origin="test1", name="bin1") with pytest.raises(ValueError): n50 = -50 bin_obj.add_N50(n50) + def test_add_quality(): completeness = 10 contamination = 6 contamination_weight = 2 - bin_obj = bin_manager.Bin(contigs={"1", "2", "e"}, origin="test1", name="bin1") + bin_obj = bin_manager.Bin(contigs={"1", "2", "e"}, origin="test1", name="bin1") bin_obj.add_quality(completeness, contamination, contamination_weight) - - assert bin_obj.completeness == completeness - assert bin_obj.contamination == contamination - - assert bin_obj.score == completeness - contamination * contamination_weight - + assert bin_obj.completeness == completeness + assert bin_obj.contamination == contamination - + assert bin_obj.score == completeness - contamination * contamination_weight # def test_two_bin_intersection(): @@ -145,7 +149,9 @@ def test_bin_union(): bin1 = bin_manager.Bin(contigs={"13", "21"}, origin="test1", name="bin1") bin2 = bin_manager.Bin(contigs={"1", "e", "2", "33"}, origin="test2", name="binA") - expected_union_bin = bin_manager.Bin(contigs={"13", "21", "1", "e", "2", "33"}, origin="", name="") + expected_union_bin = bin_manager.Bin( + contigs={"13", "21", "1", "e", "2", "33"}, origin="", name="" + ) union_bin = bin1.union(bin2) assert union_bin == expected_union_bin @@ -154,16 +160,16 @@ def test_bin_union(): def test_bin_union2(): # Create some example bins - bin1 = bin_manager.Bin({'contig1', 'contig2'}, 'origin1', 'bin1') - bin2 = bin_manager.Bin({'contig2', 'contig3'}, 'origin2', 'bin2') - bin3 = bin_manager.Bin({'contig4', 'contig5'}, 'origin3', 'bin3') + bin1 = bin_manager.Bin({"contig1", "contig2"}, "origin1", "bin1") + bin2 = bin_manager.Bin({"contig2", "contig3"}, "origin2", "bin2") + bin3 = bin_manager.Bin({"contig4", "contig5"}, "origin3", "bin3") # Perform union operation union_bin = bin1.union(bin2, bin3) # Check the result - expected_contigs = {'contig1', 'contig2', 'contig3', 'contig4', 'contig5'} - expected_origin = {'union'} + expected_contigs = {"contig1", "contig2", "contig3", "contig4", "contig5"} + expected_origin = {"union"} assert union_bin.contigs == expected_contigs assert union_bin.origin == expected_origin @@ -267,31 +273,35 @@ def test_intersection_bins_created(): set2 = [ binA, ] - bin_set_name_to_bins = {'set1': set1, - "set2":set2} - - intermediate_bins_result = bin_manager.create_intermediate_bins(bin_set_name_to_bins) - - expected_intermediate_bins = {bin_manager.Bin(contigs={"1", "2", "3"}, origin="bin1 | binA ", name="NA"), - bin_manager.Bin(contigs={"2"}, origin="bin1 - binA ", name="NA"), - bin_manager.Bin(contigs={"1"}, origin="bin1 & binA ", name="NA"), - bin_manager.Bin(contigs={"1", "4", "3"}, origin="bin2 | binA ", name="NA"), - bin_manager.Bin(contigs={"4"}, origin="bin2 - binA ", name="NA"), # binA -bin1 is equal to bin1 & binA - bin_manager.Bin(contigs={"3"}, origin="bin1 & binA ", name="NA"), - } + bin_set_name_to_bins = {"set1": set1, "set2": set2} + + intermediate_bins_result = bin_manager.create_intermediate_bins( + bin_set_name_to_bins + ) + + expected_intermediate_bins = { + bin_manager.Bin(contigs={"1", "2", "3"}, origin="bin1 | binA ", name="NA"), + bin_manager.Bin(contigs={"2"}, origin="bin1 - binA ", name="NA"), + bin_manager.Bin(contigs={"1"}, origin="bin1 & binA ", name="NA"), + bin_manager.Bin(contigs={"1", "4", "3"}, origin="bin2 | binA ", name="NA"), + bin_manager.Bin( + contigs={"4"}, origin="bin2 - binA ", name="NA" + ), # binA -bin1 is equal to bin1 & binA + bin_manager.Bin(contigs={"3"}, origin="bin1 & binA ", name="NA"), + } assert intermediate_bins_result == expected_intermediate_bins # Renames contigs in bins based on provided mapping. def test_renames_contigs(example_bin_set1): - + bin_set = [ bin_manager.Bin(contigs={"c1", "c2"}, origin="A", name="bin1"), - bin_manager.Bin(contigs={"c3", "c4"}, origin="A", name="bin2") + bin_manager.Bin(contigs={"c3", "c4"}, origin="A", name="bin2"), ] - - contig_to_index = {'c1': 1, 'c2': 2, 'c3': 3, 'c4': 4, "c5":5} + + contig_to_index = {"c1": 1, "c2": 2, "c3": 3, "c4": 4, "c5": 5} # Act bin_manager.rename_bin_contigs(bin_set, contig_to_index) @@ -302,11 +312,12 @@ def test_renames_contigs(example_bin_set1): assert bin_set[1].contigs == {3, 4} assert bin_set[1].hash == hash(str(sorted({3, 4}))) + def test_get_contigs_in_bins(): bin_set = [ bin_manager.Bin(contigs={"c1", "c2"}, origin="A", name="bin1"), bin_manager.Bin(contigs={"c3", "c4"}, origin="A", name="bin2"), - bin_manager.Bin(contigs={"c3", "c18"}, origin="A", name="bin2") + bin_manager.Bin(contigs={"c3", "c18"}, origin="A", name="bin2"), ] contigs = bin_manager.get_contigs_in_bins(bin_set) @@ -319,29 +330,26 @@ def test_dereplicate_bin_sets(): b2 = bin_manager.Bin(contigs={"c3", "c4"}, origin="A", name="bin2") b3 = bin_manager.Bin(contigs={"c3", "c18"}, origin="A", name="bin2") - bdup = bin_manager.Bin(contigs={"c3", "c18"}, origin="D", name="C") - - derep_bins_result = bin_manager.dereplicate_bin_sets([[b2,b3 ],[b1, bdup]]) + bdup = bin_manager.Bin(contigs={"c3", "c18"}, origin="D", name="C") + derep_bins_result = bin_manager.dereplicate_bin_sets([[b2, b3], [b1, bdup]]) assert derep_bins_result == {b1, b2, b3} - def test_from_bin_sets_to_bin_graph(): - bin1 = bin_manager.Bin(contigs={"1", "2"}, origin="A", name="bin1") bin2 = bin_manager.Bin(contigs={"3", "4"}, origin="A", name="bin2") bin3 = bin_manager.Bin(contigs={"5"}, origin="A", name="bin3") - set1 = [bin1,bin2,bin3] + set1 = [bin1, bin2, bin3] binA = bin_manager.Bin(contigs={"1", "3"}, origin="B", name="binA") - + set2 = [binA] - result_graph = bin_manager.from_bin_sets_to_bin_graph({"B":set2, "A":set1}) + result_graph = bin_manager.from_bin_sets_to_bin_graph({"B": set2, "A": set1}) assert result_graph.number_of_edges() == 2 # bin3 is not connected to any bin so it is not in the graph @@ -349,6 +357,7 @@ def test_from_bin_sets_to_bin_graph(): assert set(result_graph.nodes) == {binA, bin1, bin2} + @pytest.fixture def simple_bin_graph(): @@ -358,7 +367,7 @@ def simple_bin_graph(): for b in [bin1, bin2]: b.completeness = 100 b.contamination = 0 - + G = nx.Graph() G.add_edge(bin1, bin2) @@ -368,35 +377,34 @@ def simple_bin_graph(): def test_get_intersection_bins(simple_bin_graph): intersec_bins = bin_manager.get_intersection_bins(simple_bin_graph) - + assert len(intersec_bins) == 1 intersec_bin = intersec_bins.pop() assert intersec_bin.contigs == {"1", "2"} + def test_get_difference_bins(simple_bin_graph): - + difference_bins = bin_manager.get_difference_bins(simple_bin_graph) - + expected_bin1 = bin_manager.Bin(contigs={"3"}, origin="D", name="1") expected_bin2 = bin_manager.Bin(contigs={"4"}, origin="D", name="2") assert len(difference_bins) == 2 - assert difference_bins == {expected_bin1,expected_bin2} + assert difference_bins == {expected_bin1, expected_bin2} def test_get_union_bins(simple_bin_graph): - + u_bins = bin_manager.get_union_bins(simple_bin_graph) - + expected_bin1 = bin_manager.Bin(contigs={"1", "2", "3", "4"}, origin="U", name="1") assert len(u_bins) == 1 assert u_bins == {expected_bin1} - - def test_get_bins_from_contig2bin_table(tmp_path): # Create a temporary file (contig-to-bin table) for testing test_table_content = [ @@ -412,15 +420,17 @@ def test_get_bins_from_contig2bin_table(tmp_path): set_name = "TestSet" # Call the function to generate Bin objects - result_bins = bin_manager.get_bins_from_contig2bin_table(str(test_table_path), set_name) + result_bins = bin_manager.get_bins_from_contig2bin_table( + str(test_table_path), set_name + ) # Validate the result assert len(result_bins) == 2 # Check if the correct number of bins are created # Define expected bins based on the test table content expected_bins = [ - bin_manager.Bin(contigs={"contig1", "contig2"}, origin="A", name="bin1"), - bin_manager.Bin(contigs={"contig3"}, origin="A", name="bin2") + bin_manager.Bin(contigs={"contig1", "contig2"}, origin="A", name="bin1"), + bin_manager.Bin(contigs={"contig3"}, origin="A", name="bin2"), ] # Compare expected bins with the result @@ -440,8 +450,8 @@ def test_parse_contig2bin_tables(tmp_path): "set2": [ "# Sample contig-to-bin table for bin2", "contig3\tbinA", - "contig4\tbinA" - ] + "contig4\tbinA", + ], } # Create temporary files for contig-to-bin tables @@ -450,10 +460,17 @@ def test_parse_contig2bin_tables(tmp_path): table_path.write_text("\n".join(content)) # Call the function to parse contig-to-bin tables - result_bin_dict = bin_manager.parse_contig2bin_tables({name: str(tmp_path / f"test_{name}_contig2bin_table.txt") for name in test_tables}) + result_bin_dict = bin_manager.parse_contig2bin_tables( + { + name: str(tmp_path / f"test_{name}_contig2bin_table.txt") + for name in test_tables + } + ) # Validate the result - assert len(result_bin_dict) == len(test_tables) # Check if the number of bins matches the number of tables + assert len(result_bin_dict) == len( + test_tables + ) # Check if the number of bins matches the number of tables # Define expected Bin objects based on the test tables expected_bins = { @@ -463,7 +480,7 @@ def test_parse_contig2bin_tables(tmp_path): ], "set2": [ bin_manager.Bin(contigs={"contig3", "contig4"}, origin="set2", name="binA"), - ] + ], } # Compare expected bins with the result @@ -492,10 +509,17 @@ def test_parse_contig2bin_tables_with_duplicated_bins(tmp_path, caplog): table_path.write_text("\n".join(content)) # Call the function to parse contig-to-bin tables - bin_manager.parse_contig2bin_tables({name: str(tmp_path / f"test_{name}_contig2bin_table.txt") for name in test_tables}) - expected_log_message = ('2 bins with identical contig compositions detected in bin set "set1". ' - 'These bins were merged to ensure uniqueness.') - assert expected_log_message in caplog.text + bin_manager.parse_contig2bin_tables( + { + name: str(tmp_path / f"test_{name}_contig2bin_table.txt") + for name in test_tables + } + ) + expected_log_message = ( + '2 bins with identical contig compositions detected in bin set "set1". ' + "These bins were merged to ensure uniqueness." + ) + assert expected_log_message in caplog.text @pytest.fixture @@ -510,6 +534,7 @@ def create_temp_bin_files(tmpdir): return bin_dir + @pytest.fixture def create_temp_bin_directories(tmpdir, create_temp_bin_files): # Create temporary bin directories @@ -520,7 +545,6 @@ def create_temp_bin_directories(tmpdir, create_temp_bin_files): bin2 = bin_dir1.join("bin2.fasta") bin2.write(">contig3\nTTAG\n>contig4\nCGAT") - bin_dir2 = tmpdir.mkdir("set2") bin2 = bin_dir2.join("binA.fasta") bin2.write(">contig3\nTTAG\n>contig4\nCGAT\n>contig5\nCGGC") @@ -532,7 +556,9 @@ def test_get_bins_from_directory(create_temp_bin_files): bin_dir = create_temp_bin_files set_name = "TestSet" - bins = bin_manager.get_bins_from_directory(Path(bin_dir), set_name, fasta_extensions={'.fasta'}) + bins = bin_manager.get_bins_from_directory( + Path(bin_dir), set_name, fasta_extensions={".fasta"} + ) assert len(bins) == 2 # Ensure that the correct number of Bin objects is returned @@ -546,30 +572,39 @@ def test_get_bins_from_directory(create_temp_bin_files): assert bins[1].name in ["bin2.fasta", "bin1.fasta"] assert bins[0].name in ["bin2.fasta", "bin1.fasta"] + def test_get_bins_from_directory_no_files(tmpdir): bin_dir = Path(tmpdir.mkdir("empty_bins")) set_name = "EmptySet" - bins = bin_manager.get_bins_from_directory(bin_dir, set_name, fasta_extensions={'.fasta'}) + bins = bin_manager.get_bins_from_directory( + bin_dir, set_name, fasta_extensions={".fasta"} + ) + + assert ( + len(bins) == 0 + ) # Ensure that no Bin objects are returned for an empty directory - assert len(bins) == 0 # Ensure that no Bin objects are returned for an empty directory def test_get_bins_from_directory_no_wrong_extensions(create_temp_bin_files): bin_dir = Path(create_temp_bin_files) set_name = "TestSet" - bins = bin_manager.get_bins_from_directory(bin_dir, set_name, fasta_extensions={'.fna'}) - - assert len(bins) == 0 # Ensure that no Bin objects are returned for an empty directory - - + bins = bin_manager.get_bins_from_directory( + bin_dir, set_name, fasta_extensions={".fna"} + ) + assert ( + len(bins) == 0 + ) # Ensure that no Bin objects are returned for an empty directory def test_parse_bin_directories(create_temp_bin_directories): set_name_to_bin_dir = create_temp_bin_directories - bins = bin_manager.parse_bin_directories(set_name_to_bin_dir, fasta_extensions={'.fasta'}) + bins = bin_manager.parse_bin_directories( + set_name_to_bin_dir, fasta_extensions={".fasta"} + ) assert len(bins) == 2 # Ensure that the correct number of bin directories is parsed @@ -584,44 +619,46 @@ def test_parse_bin_directories(create_temp_bin_directories): def test_get_contigs_in_bin_sets(example_bin_set1, example_bin_set2, caplog): """ Test the get_contigs_in_bin_sets function for correct behavior. - + :param mock_bins: The mock_bins fixture providing test bin data. :param caplog: The pytest caplog fixture to capture logging output. """ - bin_set_name_to_bins = {"set1":example_bin_set1, - "set2":example_bin_set2} + bin_set_name_to_bins = {"set1": example_bin_set1, "set2": example_bin_set2} # Test the function with valid data with caplog.at_level(logging.WARNING): result = bin_manager.get_contigs_in_bin_sets(bin_set_name_to_bins) - + # Expected unique contigs expected_contigs = {"1", "2", "3", "4", "5"} - + # Check if the result matches expected contigs assert result == expected_contigs, "The returned set of contigs is incorrect." - + + def test_get_contigs_in_bin_sets_with_duplicated_warning(example_bin_set1, caplog): bin1 = bin_manager.Bin(contigs={"contig1", "2"}, origin="test1", name="bin1") bin2 = bin_manager.Bin(contigs={"contig1"}, origin="test1", name="binA") bin_set_name_to_bins = { - "set1":example_bin_set1, - "set_dup":{bin1, bin2}, - } + "set1": example_bin_set1, + "set_dup": {bin1, bin2}, + } # Test the function with valid data with caplog.at_level(logging.WARNING): result = bin_manager.get_contigs_in_bin_sets(bin_set_name_to_bins) - + # Expected unique contigs expected_contigs = {"1", "2", "3", "4", "5", "contig1"} - + # Check if the result matches expected contigs assert result == expected_contigs, "The returned set of contigs is incorrect." # Check for expected warnings about duplicate contigs duplicate_warning = "Bin set 'set_dup' contains 1 duplicated contigs. Details: contig1 (found 2 times)" - assert duplicate_warning in caplog.text, "The warning for duplicate contigs was not logged correctly." + assert ( + duplicate_warning in caplog.text + ), "The warning for duplicate contigs was not logged correctly." diff --git a/tests/bin_quality_test.py b/tests/bin_quality_test.py index 838086a..d1ce3a6 100644 --- a/tests/bin_quality_test.py +++ b/tests/bin_quality_test.py @@ -11,14 +11,16 @@ add_bin_metrics, assess_bins_quality_by_chunk, assess_bins_quality, - chunks, - get_diamond_feature_per_bin_df, - get_bins_metadata_df) + chunks, + get_diamond_feature_per_bin_df, + get_bins_metadata_df, +) from checkm2 import keggData, modelPostprocessing, modelProcessing from unittest.mock import Mock, patch, MagicMock + def test_compute_N50(): assert bin_quality.compute_N50([50]) == 50 assert bin_quality.compute_N50([0]) == 0 @@ -27,7 +29,6 @@ def test_compute_N50(): assert bin_quality.compute_N50([1, 3, 3, 4, 5, 5, 6, 9, 10, 24]) == 9 - def test_chunks(): # Test case 1 iterable_1 = [1, 2, 3, 4, 5, 6] @@ -60,7 +61,6 @@ def test_chunks(): result_4 = list(chunks(iterable_4, size_4)) assert result_4 == expected_output_4 - class Bin: @@ -78,38 +78,67 @@ def add_N50(self, N50): def add_N50(self, N50): self.N50 = N50 - + def add_quality(self, comp, cont, weight): self.completeness = comp self.contamination = cont self.score = comp - weight * cont - + + def test_get_bins_metadata_df(): # Mock input data - bins = [ - Bin(1, ['contig1', 'contig3']), - Bin(2, ['contig2']) - ] + bins = [Bin(1, ["contig1", "contig3"]), Bin(2, ["contig2"])] - contig_to_cds_count = {'contig1': 10, 'contig2': 45, 'contig3': 20, 'contig4': 25} - contig_to_aa_counter = {'contig1': Counter({'A': 5, 'D': 10}), 'contig2': Counter({'G': 8, 'V': 12, 'T': 2}), - 'contig3': Counter({'D': 8, 'Y': 12})} - contig_to_aa_length = {'contig1': 1000, 'contig2': 1500, 'contig3': 2000, 'contig4': 2500} + contig_to_cds_count = {"contig1": 10, "contig2": 45, "contig3": 20, "contig4": 25} + contig_to_aa_counter = { + "contig1": Counter({"A": 5, "D": 10}), + "contig2": Counter({"G": 8, "V": 12, "T": 2}), + "contig3": Counter({"D": 8, "Y": 12}), + } + contig_to_aa_length = { + "contig1": 1000, + "contig2": 1500, + "contig3": 2000, + "contig4": 2500, + } # Call the function - result_df = bin_quality.get_bins_metadata_df(bins, contig_to_cds_count, contig_to_aa_counter, contig_to_aa_length) + result_df = bin_quality.get_bins_metadata_df( + bins, contig_to_cds_count, contig_to_aa_counter, contig_to_aa_length + ) # Define expected values based on the provided input expected_columns = [ - 'Name', 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', - 'AALength', 'CDS' + "Name", + "A", + "C", + "D", + "E", + "F", + "G", + "H", + "I", + "K", + "L", + "M", + "N", + "P", + "Q", + "R", + "S", + "T", + "V", + "W", + "Y", + "AALength", + "CDS", ] expected_index = [1, 2] expected_values = [ [1, 5, 0, 18, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12, 3000, 30], - [2, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 12, 0, 0, 1500, 45] + [2, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 12, 0, 0, 1500, 45], ] # Check if the generated DataFrame matches the expected DataFrame @@ -118,46 +147,38 @@ def test_get_bins_metadata_df(): assert result_df.values.tolist() == expected_values - - def test_get_diamond_feature_per_bin_df(): # Mock input data - bins = [ - Bin(1, ['contig1', 'contig2']), - Bin(2, ['contig3', 'contig4']) - ] + bins = [Bin(1, ["contig1", "contig2"]), Bin(2, ["contig3", "contig4"])] contig_to_kegg_counter = { - 'contig1': Counter({'K01810': 5, 'K15916': 7}), - 'contig2': Counter({'K01810': 10}), - 'contig3': Counter({'K00918': 8}), + "contig1": Counter({"K01810": 5, "K15916": 7}), + "contig2": Counter({"K01810": 10}), + "contig3": Counter({"K00918": 8}), } # Call the function - result_df, default_ko_count = bin_quality.get_diamond_feature_per_bin_df(bins, contig_to_kegg_counter) + result_df, default_ko_count = bin_quality.get_diamond_feature_per_bin_df( + bins, contig_to_kegg_counter + ) expected_index = [1, 2] - - - assert result_df.index.tolist() == expected_index - assert result_df.loc[1,"K01810"] == 15 # in bin1 from contig 1 and 2 - assert result_df.loc[1,"K15916"] == 7 # in bin1 from contig 1 - assert result_df.loc[2,"K01810"] == 0 # this ko is not in any contig of bin 2 - assert result_df.loc[2,"K00918"] == 8 # in bin2 from contig 3 + assert result_df.index.tolist() == expected_index + assert result_df.loc[1, "K01810"] == 15 # in bin1 from contig 1 and 2 + assert result_df.loc[1, "K15916"] == 7 # in bin1 from contig 1 + assert result_df.loc[2, "K01810"] == 0 # this ko is not in any contig of bin 2 + assert result_df.loc[2, "K00918"] == 8 # in bin2 from contig 3 def test_add_bin_size_and_N50(): # Mock input data - bins = [ - Bin(1, ['contig1', 'contig2']), - Bin(2, ['contig3']) - ] + bins = [Bin(1, ["contig1", "contig2"]), Bin(2, ["contig3"])] contig_to_size = { - 'contig1': 1000, - 'contig2': 1500, - 'contig3': 2000, + "contig1": 1000, + "contig2": 1500, + "contig3": 2000, } # Call the function @@ -165,47 +186,46 @@ def test_add_bin_size_and_N50(): # Assertions to verify if add_length and add_N50 were called with the correct values assert bins[0].length == 2500 - assert bins[0].N50 == 1500 + assert bins[0].N50 == 1500 assert bins[1].length == 2000 - assert bins[1].N50 == 2000 - - + assert bins[1].N50 == 2000 def mock_modelProcessor(thread): return "mock_modelProcessor" + def test_add_bin_metrics(monkeypatch): # Mock input data - bins = [ - Bin(1, ['contig1', 'contig2']), - Bin(2, ['contig3']) - ] + bins = [Bin(1, ["contig1", "contig2"]), Bin(2, ["contig3"])] contig_info = { # Add mocked contig information here as needed - "contig_to_kegg_counter":{}, - "contig_to_cds_count":{}, - "contig_to_aa_counter":{}, - "contig_to_aa_length":{}, - "contig_to_length":{}, + "contig_to_kegg_counter": {}, + "contig_to_cds_count": {}, + "contig_to_aa_counter": {}, + "contig_to_aa_length": {}, + "contig_to_length": {}, } contamination_weight = 0.5 threads = 1 - monkeypatch.setattr(modelPostprocessing, "modelProcessor", mock_modelProcessor) - # Mock the functions called within add_bin_metrics - with patch('binette.bin_quality.add_bin_size_and_N50') as mock_add_bin_size_and_N50, \ - patch('binette.bin_quality.assess_bins_quality_by_chunk') as mock_assess_bins_quality_by_chunk: + with patch( + "binette.bin_quality.add_bin_size_and_N50" + ) as mock_add_bin_size_and_N50, patch( + "binette.bin_quality.assess_bins_quality_by_chunk" + ) as mock_assess_bins_quality_by_chunk: add_bin_metrics(bins, contig_info, contamination_weight, threads) # Assertions to check if functions were called with the expected arguments - mock_add_bin_size_and_N50.assert_called_once_with(bins, contig_info["contig_to_length"]) + mock_add_bin_size_and_N50.assert_called_once_with( + bins, contig_info["contig_to_length"] + ) mock_assess_bins_quality_by_chunk.assert_called_once_with( bins, contig_info["contig_to_kegg_counter"], @@ -216,72 +236,71 @@ def test_add_bin_metrics(monkeypatch): "mock_modelProcessor", # Mocked postProcessor object ) + def test_assess_bins_quality_by_chunk(monkeypatch): # Prepare input data for testing bins = [ - Bin(1, ['contig1', 'contig2']), - Bin(2, ['contig3', 'contig4']), - Bin(3, ['contig3', 'contig4']) + Bin(1, ["contig1", "contig2"]), + Bin(2, ["contig3", "contig4"]), + Bin(3, ["contig3", "contig4"]), ] contig_to_kegg_counter = {} contig_to_cds_count = {} contig_to_aa_counter = {} - contig_to_aa_length = {} - contamination_weight = 0.5 + contig_to_aa_length = {} + contamination_weight = 0.5 # Mocking postProcessor object monkeypatch.setattr(modelPostprocessing, "modelProcessor", mock_modelProcessor) - # Mock the functions called within add_bin_metrics - with patch('binette.bin_quality.assess_bins_quality') as mock_assess_bins_quality: + with patch("binette.bin_quality.assess_bins_quality") as mock_assess_bins_quality: assess_bins_quality_by_chunk( - bins, - contig_to_kegg_counter, - contig_to_cds_count, - contig_to_aa_counter, - contig_to_aa_length, - contamination_weight, - postProcessor=None, - threads=1, - chunk_size = 3 - ) + bins, + contig_to_kegg_counter, + contig_to_cds_count, + contig_to_aa_counter, + contig_to_aa_length, + contamination_weight, + postProcessor=None, + threads=1, + chunk_size=3, + ) # Chunk size > number of bin so only one chunk mock_assess_bins_quality.assert_called_once_with( - bins= set(bins), - contig_to_kegg_counter= contig_to_kegg_counter, + bins=set(bins), + contig_to_kegg_counter=contig_to_kegg_counter, contig_to_cds_count=contig_to_cds_count, contig_to_aa_counter=contig_to_aa_counter, contig_to_aa_length=contig_to_aa_length, contamination_weight=contamination_weight, postProcessor=None, - threads=1 + threads=1, ) # Mock the functions called within add_bin_metrics - with patch('binette.bin_quality.assess_bins_quality') as mock_assess_bins_quality: + with patch("binette.bin_quality.assess_bins_quality") as mock_assess_bins_quality: assess_bins_quality_by_chunk( - bins, - contig_to_kegg_counter, - contig_to_cds_count, - contig_to_aa_counter, - contig_to_aa_length, - contamination_weight, - postProcessor=None, - threads=1, - chunk_size = 2 - ) + bins, + contig_to_kegg_counter, + contig_to_cds_count, + contig_to_aa_counter, + contig_to_aa_length, + contamination_weight, + postProcessor=None, + threads=1, + chunk_size=2, + ) # Chunk size < number of bin so 2 chunks with [bin1,bin2] and [bin3] assert mock_assess_bins_quality.call_count == 2 - from unittest.mock import patch, MagicMock import numpy as np import pandas as pd @@ -290,10 +309,7 @@ def test_assess_bins_quality_by_chunk(monkeypatch): def test_assess_bins_quality(): # Prepare mock input data for testing - bins = [ - Bin(1, ['contig1', 'contig2']), - Bin(2, ['contig3', 'contig4']) - ] + bins = [Bin(1, ["contig1", "contig2"]), Bin(2, ["contig3", "contig4"])] contig_to_kegg_counter = {} contig_to_cds_count = {} @@ -301,7 +317,6 @@ def test_assess_bins_quality(): contig_to_aa_counter = {} contamination_weight = 0.5 - # Call the function being tested assess_bins_quality( bins, @@ -309,14 +324,15 @@ def test_assess_bins_quality(): contig_to_cds_count, contig_to_aa_counter, contig_to_aa_length, - contamination_weight + contamination_weight, ) - - # Verify the expected calls to add_quality for each bin object for bin_obj in bins: assert bin_obj.completeness is not None assert bin_obj.contamination is not None assert bin_obj.score is not None - assert bin_obj.score == bin_obj.completeness - bin_obj.contamination * contamination_weight \ No newline at end of file + assert ( + bin_obj.score + == bin_obj.completeness - bin_obj.contamination * contamination_weight + ) diff --git a/tests/cds_test.py b/tests/cds_test.py index 20eb787..97380fc 100644 --- a/tests/cds_test.py +++ b/tests/cds_test.py @@ -5,23 +5,31 @@ from pathlib import Path from unittest.mock import mock_open, patch + class MockContig: def __init__(self, name, seq): self.seq = seq self.name = name + @pytest.fixture def contig1(): - contig = MockContig(name="contig1", - seq="ATGAGCATCAGGGGAGTAGGAGGATGCAACGGGAATAGTCGAATCCCTTCTCATAATGGGGATGGATCGAATCGCAGAAGTCAAAATACGAAGGGTAATAATAAAGTTGAAGATCGAGTTTGT") + contig = MockContig( + name="contig1", + seq="ATGAGCATCAGGGGAGTAGGAGGATGCAACGGGAATAGTCGAATCCCTTCTCATAATGGGGATGGATCGAATCGCAGAAGTCAAAATACGAAGGGTAATAATAAAGTTGAAGATCGAGTTTGT", + ) return contig + @pytest.fixture def contig2(): - contig = MockContig(name="contig2", - seq="TTGGTCGTATGACTGATAATTTCTCAGACATTGAAAACTTTAATGAAATTTTCAACAGAAAACCTGCTTTACAATTTCGTTTTTA") + contig = MockContig( + name="contig2", + seq="TTGGTCGTATGACTGATAATTTCTCAGACATTGAAAACTTTAATGAAATTTTCAACAGAAAACCTGCTTTACAATTTCGTTTTTA", + ) return contig + @pytest.fixture def orf_finder(): try: @@ -32,6 +40,7 @@ def orf_finder(): return orf_finder + # Predict open reading frames with Pyrodigal using 1 thread. def test_predict_orf_with_1_thread(contig1, contig2): @@ -60,7 +69,7 @@ def test_predict_orf_with_multiple_threads(contig1, contig2): threads = 4 result = cds.predict(contigs_iterator, outfaa, threads) - + assert isinstance(result, dict) assert len(result) == 2 assert "contig1" in result @@ -75,7 +84,6 @@ def test_predict_orf_with_multiple_threads(contig1, contig2): def test_predict_genes(contig1, orf_finder): - result = cds.predict_genes(orf_finder.find_genes, contig1) assert isinstance(result, tuple) @@ -94,6 +102,7 @@ def test_extract_contig_name_from_cds_name(): assert isinstance(result, str) assert result == "contig1" + def test_extract_contig_name_from_cds_name(): cds_name = "contig1_gene1" @@ -104,9 +113,9 @@ def test_extract_contig_name_from_cds_name(): def test_write_faa(contig1, orf_finder): - + predicted_genes = orf_finder.find_genes(contig1.seq) - contig_name = 'contig' + contig_name = "contig" output_file = "tests/tmp_file.faa" cds.write_faa(output_file, [(contig_name, predicted_genes)]) @@ -123,7 +132,7 @@ def test_parse_faa_file(tmp_path): # at least one protein sequence to not triger the error fasta_content = ( ">contig1_gene1\n" - "MPPPAOSKNSKSS\n" + "MPPPAOSKNSKSS\n" ">contig1_gene2\n" "CCCCCCCCCCC\n" ">contig2_gene1\n" @@ -137,8 +146,8 @@ def test_parse_faa_file(tmp_path): # Check if the output matches the expected dictionary expected_result = { - 'contig1': ['MPPPAOSKNSKSS', 'CCCCCCCCCCC'], - 'contig2': ['TTTTTTTTTTTT'] + "contig1": ["MPPPAOSKNSKSS", "CCCCCCCCCCC"], + "contig2": ["TTTTTTTTTTTT"], } assert result == expected_result @@ -156,62 +165,64 @@ def test_parse_faa_file_raises_error_for_dna(tmp_path): fna_file = tmp_path / "mock_file.fna" fna_file.write_text(fasta_content) - # Check that ValueError is raised when DNA sequences are encountered with pytest.raises(ValueError): cds.parse_faa_file(fna_file) - def test_get_aa_composition(): - genes = ['AAAA', - "CCCC", - "TTTT", - "GGGG"] + genes = ["AAAA", "CCCC", "TTTT", "GGGG"] result = cds.get_aa_composition(genes) - assert dict(result) == {'A': 4, 'C': 4, 'T': 4, 'G': 4} + assert dict(result) == {"A": 4, "C": 4, "T": 4, "G": 4} + def test_get_contig_cds_metadata_flat(): - contig_to_genes = {"c1":["AAAA", "GGGG", "CCCC"], - "c2":["TTTT", "CCCC"]} + contig_to_genes = {"c1": ["AAAA", "GGGG", "CCCC"], "c2": ["TTTT", "CCCC"]} + + contig_to_cds_count, contig_to_aa_counter, contig_to_aa_length = ( + cds.get_contig_cds_metadata_flat(contig_to_genes) + ) + + assert contig_to_cds_count == {"c1": 3, "c2": 2} + assert contig_to_aa_counter == { + "c1": {"A": 4, "G": 4, "C": 4}, + "c2": {"C": 4, "T": 4}, + } + assert contig_to_aa_length == {"c1": 12, "c2": 8} - contig_to_cds_count, contig_to_aa_counter, contig_to_aa_length = cds.get_contig_cds_metadata_flat(contig_to_genes) - - assert contig_to_cds_count == {"c1":3, "c2":2} - assert contig_to_aa_counter == {"c1": {'A': 4, 'G': 4, "C":4} , "c2":{'C': 4, 'T': 4}} - assert contig_to_aa_length == {"c1":12, "c2":8} def test_get_contig_cds_metadata(): - contig_to_genes = {"c1":["AAAA", "GGGG", "CCCC"], - "c2":["TTTT", "CCCC"]} + contig_to_genes = {"c1": ["AAAA", "GGGG", "CCCC"], "c2": ["TTTT", "CCCC"]} contig_metadata = cds.get_contig_cds_metadata(contig_to_genes, 1) - - assert contig_metadata['contig_to_cds_count'] == {"c1":3, "c2":2} - assert contig_metadata['contig_to_aa_counter'] == {"c1": {'A': 4, 'G': 4, "C":4} , "c2":{'C': 4, 'T': 4}} - assert contig_metadata['contig_to_aa_length'] == {"c1":12, "c2":8} + assert contig_metadata["contig_to_cds_count"] == {"c1": 3, "c2": 2} + assert contig_metadata["contig_to_aa_counter"] == { + "c1": {"A": 4, "G": 4, "C": 4}, + "c2": {"C": 4, "T": 4}, + } + assert contig_metadata["contig_to_aa_length"] == {"c1": 12, "c2": 8} # Test function def test_is_nucleic_acid(): # Valid DNA sequence assert cds.is_nucleic_acid("ATCG") is True - assert cds.is_nucleic_acid("ATCNNNNNG") is True # N can be found in DNA seq + assert cds.is_nucleic_acid("ATCNNNNNG") is True # N can be found in DNA seq # Valid RNA sequence assert cds.is_nucleic_acid("AUGCAUGC") is True - + # Mixed case assert cds.is_nucleic_acid("AtCg") is True - + # Invalid sequence (contains characters not part of DNA or RNA) assert cds.is_nucleic_acid("ATCX") is False # 'X' is not a valid base assert cds.is_nucleic_acid("AUG#C") is False # '#' is not a valid base - + # Amino acid sequence assert cds.is_nucleic_acid("MSIRGVGGNGNSR") is False # Numbers are invalid diff --git a/tests/contig_manager_test.py b/tests/contig_manager_test.py index 16ed1f8..38aca8b 100644 --- a/tests/contig_manager_test.py +++ b/tests/contig_manager_test.py @@ -2,6 +2,7 @@ import pyfastx import pytest + # Parses a valid FASTA file and returns a pyfastx.Fasta object. def test_valid_fasta_file(): # Arrange @@ -13,28 +14,30 @@ def test_valid_fasta_file(): # Assert assert isinstance(result, pyfastx.Fasta) + # Parses an invalid FASTA file and raises an exception. def test_invalid_fasta_file(): - # + # fasta_file = "tests/contig_manager_test.py" # Act and Assert with pytest.raises(RuntimeError): contig_manager.parse_fasta_file(fasta_file) - + # The function returns a tuple containing two dictionaries. def test_returns_tuple(): - contigs = ['contig1', 'contig2', 'contig3'] + contigs = ["contig1", "contig2", "contig3"] result = contig_manager.make_contig_index(contigs) assert isinstance(result, tuple) assert len(result) == 2 assert isinstance(result[0], dict) assert isinstance(result[1], dict) + # The first dictionary maps contig names to their index. def test_contig_to_index_mapping(): - contigs = ['contig1', 'contig2', 'contig3'] + contigs = ["contig1", "contig2", "contig3"] result = contig_manager.make_contig_index(contigs) contig_to_index = result[0] assert isinstance(contig_to_index, dict) @@ -43,12 +46,13 @@ def test_contig_to_index_mapping(): assert contig in contig_to_index assert isinstance(contig_to_index[contig], int) + # The function returns a dictionary with the same number of items as the input dictionary. def test_same_number_of_items(): - contig_to_index = {'contig1': 0, 'contig2': 1, 'contig3': 2} - contig_to_info = {'contig1': 'info1', 'contig2': 'info2', 'contig3': 'info3'} - expected_result = {0: 'info1', 1: 'info2', 2: 'info3'} + contig_to_index = {"contig1": 0, "contig2": 1, "contig3": 2} + contig_to_info = {"contig1": "info1", "contig2": "info2", "contig3": "info3"} + expected_result = {0: "info1", 1: "info2", 2: "info3"} result = contig_manager.apply_contig_index(contig_to_index, contig_to_info) - assert len(result) == len(expected_result) \ No newline at end of file + assert len(result) == len(expected_result) diff --git a/tests/diamonds_test.py b/tests/diamonds_test.py index a2d23a3..eda2975 100644 --- a/tests/diamonds_test.py +++ b/tests/diamonds_test.py @@ -16,6 +16,7 @@ import pandas as pd from collections import Counter + class CompletedProcess: def __init__(self, returncode, stderr): self.returncode = returncode @@ -26,6 +27,7 @@ def mock_shutil_which(*args, **kwargs): if args[0] == "checkm2": return "checkm2" + def test_get_checkm2_db_no_checkm2(monkeypatch): # Mocking shutil.which def mock_shutil_which_none(*args, **kwargs): @@ -40,14 +42,19 @@ def mock_shutil_which_none(*args, **kwargs): assert pytest_wrapped_e.type == SystemExit assert pytest_wrapped_e.value.code == 1 + def test_get_checkm2_db_with_success(monkeypatch): def mock_subprocess_run(*args, **kwargs): # Simulating the behavior of checkm2 command - if args[0][0] == "checkm2" and args[0][1] == "database" and args[0][2] == "--current": + if ( + args[0][0] == "checkm2" + and args[0][1] == "database" + and args[0][2] == "--current" + ): return CompletedProcess(0, "INFO: /mocked/path/to/checkm2.dmnd") - + monkeypatch.setattr(subprocess, "run", mock_subprocess_run) monkeypatch.setattr(shutil, "which", mock_shutil_which) @@ -57,14 +64,18 @@ def mock_subprocess_run(*args, **kwargs): expected_path = "/mocked/path/to/checkm2.dmnd" assert result == expected_path + def test_get_checkm2_db_checkm2_exit_error(monkeypatch): def mock_subprocess_run(*args, **kwargs): # Simulating the behavior of checkm2 command - if args[0][0] == "checkm2" and args[0][1] == "database" and args[0][2] == "--current": + if ( + args[0][0] == "checkm2" + and args[0][1] == "database" + and args[0][2] == "--current" + ): return CompletedProcess(2, "") - monkeypatch.setattr(subprocess, "run", mock_subprocess_run) monkeypatch.setattr(shutil, "which", mock_shutil_which) @@ -76,14 +87,22 @@ def mock_subprocess_run(*args, **kwargs): assert pytest_wrapped_e.type == SystemExit assert pytest_wrapped_e.value.code == 1 + def test_get_checkm2_db_wrong_path_format(monkeypatch): def mock_subprocess_run(*args, **kwargs): # Simulating the behavior of checkm2 command - if args[0][0] == "checkm2" and args[0][1] == "database" and args[0][2] == "--current": - return CompletedProcess(0, "UNEXPECTED PATH FORMAT RETURNED BY CHECKM2: /mocked/path/to/checkm2.dmnd") - + if ( + args[0][0] == "checkm2" + and args[0][1] == "database" + and args[0][2] == "--current" + ): + return CompletedProcess( + 0, + "UNEXPECTED PATH FORMAT RETURNED BY CHECKM2: /mocked/path/to/checkm2.dmnd", + ) + monkeypatch.setattr(subprocess, "run", mock_subprocess_run) monkeypatch.setattr(shutil, "which", mock_shutil_which) @@ -95,7 +114,6 @@ def mock_subprocess_run(*args, **kwargs): assert pytest_wrapped_e.value.code == 1 - def test_check_tool_exists_tool_found(monkeypatch): # Mocking shutil.which def mock_shutil_which(*args, **kwargs): @@ -120,14 +138,13 @@ def mock_shutil_which(*args, **kwargs): # Call the function and expect FileNotFoundError with pytest.raises(FileNotFoundError): diamond.check_tool_exists("non_existing_tool") - - - def test_run_diamond_tool_found(monkeypatch): - monkeypatch.setattr(sys, "exit", lambda x: None) # Patch sys.exit to avoid test interruption + monkeypatch.setattr( + sys, "exit", lambda x: None + ) # Patch sys.exit to avoid test interruption # Mocking subprocess.run def mock_subprocess_run(*args, **kwargs): @@ -136,7 +153,10 @@ def __init__(self, returncode): self.returncode = returncode # Simulating successful run of diamond command - if args[0] == "diamond blastp --outfmt 6 --max-target-seqs 1 --query test.faa -o output.txt --threads 1 --db db --query-cover 80 --subject-cover 80 --id 30 --evalue 1e-05 --block-size 2 2> log.txt": + if ( + args[0] + == "diamond blastp --outfmt 6 --max-target-seqs 1 --query test.faa -o output.txt --threads 1 --db db --query-cover 80 --subject-cover 80 --id 30 --evalue 1e-05 --block-size 2 2> log.txt" + ): return CompletedProcess(0) monkeypatch.setattr(subprocess, "run", mock_subprocess_run) @@ -144,8 +164,16 @@ def __init__(self, returncode): # Call the function diamond.run( - "test.faa", "output.txt", "db", "log.txt", threads=1, query_cover=80, subject_cover=80, percent_id=30, - evalue=1e-05, low_mem=False + "test.faa", + "output.txt", + "db", + "log.txt", + threads=1, + query_cover=80, + subject_cover=80, + percent_id=30, + evalue=1e-05, + low_mem=False, ) @@ -157,10 +185,18 @@ def mock_check_tool_exists(*args, **kwargs): monkeypatch.setattr(logging, "error", lambda x: None) # Avoid logging during test # Call the function and expect it to raise FileNotFoundError - with patch('sys.exit') as mock_exit: + with patch("sys.exit") as mock_exit: diamond.run( - "test.faa", "output.txt", "db", "log.txt", threads=1, query_cover=80, subject_cover=80, percent_id=30, - evalue=1e-05, low_mem=False + "test.faa", + "output.txt", + "db", + "log.txt", + threads=1, + query_cover=80, + subject_cover=80, + percent_id=30, + evalue=1e-05, + low_mem=False, ) mock_exit.assert_called_once_with(1) @@ -172,22 +208,33 @@ def test_get_contig_to_kegg_id(): # Mocked dataframe representing the data read from the Diamond result file mocked_data = { - "ProteinID": ["contig1_protein1", "contig1_protein2", "contig2_protein1", "contig2_protein2"], - "annotation": ["protein1_annotation~K12345", "protein2_annotation~K67890", "protein3_annotation~K23456", "protein4_annotation~K66666"] + "ProteinID": [ + "contig1_protein1", + "contig1_protein2", + "contig2_protein1", + "contig2_protein2", + ], + "annotation": [ + "protein1_annotation~K12345", + "protein2_annotation~K67890", + "protein3_annotation~K23456", + "protein4_annotation~K66666", + ], } mocked_df = pd.DataFrame(mocked_data) # Mocked return values for keggData.KeggCalculator() and KeggCalc.return_default_values_from_category() class MockedKeggCalculator: def return_default_values_from_category(self, category): - return {"K12345": 2, "K67890": 1, "K23456":3} + return {"K12345": 2, "K67890": 1, "K23456": 3} - mocked_kegg_calculator = MockedKeggCalculator() + mocked_kegg_calculator = MockedKeggCalculator() # Mocking relevant functions and classes used within the function - with patch("pandas.read_csv", return_value=mocked_df), \ - patch("checkm2.keggData.KeggCalculator", return_value=mocked_kegg_calculator): - + with patch("pandas.read_csv", return_value=mocked_df), patch( + "checkm2.keggData.KeggCalculator", return_value=mocked_kegg_calculator + ): + # Call the function result = diamond.get_contig_to_kegg_id(diamond_result_file) @@ -196,7 +243,7 @@ def return_default_values_from_category(self, category): # Define the expected output based on the mocked data expected_result = { "contig1": Counter({"K12345": 1, "K67890": 1}), - "contig2": Counter({"K23456": 1}) + "contig2": Counter({"K23456": 1}), } # Check if the function output matches the expected result diff --git a/tests/io_manager_test.py b/tests/io_manager_test.py index 3a44368..3d857ab 100644 --- a/tests/io_manager_test.py +++ b/tests/io_manager_test.py @@ -4,10 +4,19 @@ from unittest.mock import patch - - class Bin: - def __init__(self, bin_id, origin, name, completeness, contamination, score, length, N50, contigs): + def __init__( + self, + bin_id, + origin, + name, + completeness, + contamination, + score, + length, + N50, + contigs, + ): self.id = bin_id self.origin = {origin} self.name = name @@ -18,48 +27,51 @@ def __init__(self, bin_id, origin, name, completeness, contamination, score, len self.N50 = N50 self.contigs = contigs + @pytest.fixture def bin1(): - return Bin(1, 'origin1', 'name1', 90, 5, 80, 1000, 500, ['contig1', 'contig3']) + return Bin(1, "origin1", "name1", 90, 5, 80, 1000, 500, ["contig1", "contig3"]) + @pytest.fixture def bin2(): - return Bin(2, 'origin2', 'name2', 85, 8, 75, 1200, 600, ['contig2', 'contig4']) + return Bin(2, "origin2", "name2", 85, 8, 75, 1200, 600, ["contig2", "contig4"]) def test_infer_bin_name_from_bin_inputs(): # Mock input data - input_bins = [ - '/path/to/bin1', - '/path/to/bin2', - '/path/to/bin3' - ] + input_bins = ["/path/to/bin1", "/path/to/bin2", "/path/to/bin3"] # Call the function - result = io_manager.infer_bin_set_names_from_input_paths(list(map(Path, input_bins))) + result = io_manager.infer_bin_set_names_from_input_paths( + list(map(Path, input_bins)) + ) # Define the expected output expected_result = { - 'bin1': Path('/path/to/bin1'), - 'bin2': Path('/path/to/bin2'), - 'bin3': Path('/path/to/bin3') + "bin1": Path("/path/to/bin1"), + "bin2": Path("/path/to/bin2"), + "bin3": Path("/path/to/bin3"), } # Check if the output matches the expected dictionary assert result == expected_result + def test_infer_bin_name_from_single_path(): # Mock input data input_bins = [ - '/path/to/bin1', + "/path/to/bin1", ] # Call the function - result = io_manager.infer_bin_set_names_from_input_paths(list(map(Path, input_bins))) + result = io_manager.infer_bin_set_names_from_input_paths( + list(map(Path, input_bins)) + ) # Define the expected output expected_result = { - '/path/to/bin1': Path('/path/to/bin1'), + "/path/to/bin1": Path("/path/to/bin1"), } # Check if the output matches the expected dictionary @@ -68,20 +80,18 @@ def test_infer_bin_name_from_single_path(): def test_infer_bin_name_from_bin_table_inputs(): # Mock input data - input_bins = [ - '/path/to/bin1.tsv', - '/path/to/bin2.tsv', - '/path/to/bin3.tsv' - ] + input_bins = ["/path/to/bin1.tsv", "/path/to/bin2.tsv", "/path/to/bin3.tsv"] # Call the function - result = io_manager.infer_bin_set_names_from_input_paths(list(map(Path, input_bins))) + result = io_manager.infer_bin_set_names_from_input_paths( + list(map(Path, input_bins)) + ) # Define the expected output expected_result = { - 'bin1': Path('/path/to/bin1.tsv'), - 'bin2': Path('/path/to/bin2.tsv'), - 'bin3': Path('/path/to/bin3.tsv') + "bin1": Path("/path/to/bin1.tsv"), + "bin2": Path("/path/to/bin2.tsv"), + "bin3": Path("/path/to/bin3.tsv"), } # Check if the output matches the expected dictionary @@ -90,76 +100,114 @@ def test_infer_bin_name_from_bin_table_inputs(): def test_infer_bin_name_from_bin_table_with_different_ext(): # Mock input data - input_bins = [ - '/path/to/bin1.tsv', - '/path/to/bin2.tsv', - '/path/to/bin3.txt' - ] + input_bins = ["/path/to/bin1.tsv", "/path/to/bin2.tsv", "/path/to/bin3.txt"] # Call the function - result = io_manager.infer_bin_set_names_from_input_paths(list(map(Path, input_bins))) + result = io_manager.infer_bin_set_names_from_input_paths( + list(map(Path, input_bins)) + ) # Define the expected output expected_result = { - 'bin1.tsv': Path('/path/to/bin1.tsv'), - 'bin2.tsv': Path('/path/to/bin2.tsv'), - 'bin3.txt': Path('/path/to/bin3.txt') + "bin1.tsv": Path("/path/to/bin1.tsv"), + "bin2.tsv": Path("/path/to/bin2.tsv"), + "bin3.txt": Path("/path/to/bin3.txt"), } # Check if the output matches the expected dictionary assert result == expected_result + def test_infer_bin_name_from_bin_table_with_different_dir(): # Mock input data input_bins = [ - '/path/to/bins', - '/path2/result_bins', - '/path2/result/bins', + "/path/to/bins", + "/path2/result_bins", + "/path2/result/bins", ] # Call the function - result = io_manager.infer_bin_set_names_from_input_paths(list(map(Path, input_bins))) + result = io_manager.infer_bin_set_names_from_input_paths( + list(map(Path, input_bins)) + ) # Define the expected output expected_result = { - 'path/to/bins' : Path('/path/to/bins'), - 'path2/result_bins': Path('/path2/result_bins'), - 'path2/result/bins': Path('/path2/result/bins'), + "path/to/bins": Path("/path/to/bins"), + "path2/result_bins": Path("/path2/result_bins"), + "path2/result/bins": Path("/path2/result/bins"), } # Check if the output matches the expected dictionary assert result == expected_result - + + def test_get_paths_common_prefix_suffix(): # Test case 1: No paths provided assert io_manager.get_paths_common_prefix_suffix([]) == ([], [], []) # # Test case 2: Single path - assert io_manager.get_paths_common_prefix_suffix([Path('/home/user/project')]) == (['/', 'home', 'user', 'project'], ['/', 'home', 'user', 'project'], []) + assert io_manager.get_paths_common_prefix_suffix([Path("/home/user/project")]) == ( + ["/", "home", "user", "project"], + ["/", "home", "user", "project"], + [], + ) # Test case 3: Multiple paths with common prefix and suffix - paths = [Path('/home/user/project/src'), Path('/home/user/project/docs'), Path('/home/user/project/tests')] - assert io_manager.get_paths_common_prefix_suffix(paths) == (['/', 'home', 'user', 'project'], [], []) + paths = [ + Path("/home/user/project/src"), + Path("/home/user/project/docs"), + Path("/home/user/project/tests"), + ] + assert io_manager.get_paths_common_prefix_suffix(paths) == ( + ["/", "home", "user", "project"], + [], + [], + ) # Test case 4: Multiple paths with no common prefix or suffix - paths = [Path('/var/log/syslog'), Path('/usr/local/bin/python'), Path('/etc/nginx/nginx.conf')] - assert io_manager.get_paths_common_prefix_suffix(paths) == (['/'], [], []) + paths = [ + Path("/var/log/syslog"), + Path("/usr/local/bin/python"), + Path("/etc/nginx/nginx.conf"), + ] + assert io_manager.get_paths_common_prefix_suffix(paths) == (["/"], [], []) # Test case 5: Multiple paths with common suffix - paths = [Path('/home/user/docs/report.txt'), Path('/home/admin/docs/report.txt')] - assert io_manager.get_paths_common_prefix_suffix(paths) == (['/', 'home'], ['docs', 'report.txt'], ['.txt']) + paths = [Path("/home/user/docs/report.txt"), Path("/home/admin/docs/report.txt")] + assert io_manager.get_paths_common_prefix_suffix(paths) == ( + ["/", "home"], + ["docs", "report.txt"], + [".txt"], + ) # Test case 6: Paths with a deeper common prefix and suffix - paths = [Path('/data/project_a/results/output.txt'), Path('/data/project_b/results/output.txt')] - assert io_manager.get_paths_common_prefix_suffix(paths) == (['/', 'data'], ['results', 'output.txt'], ['.txt']) + paths = [ + Path("/data/project_a/results/output.txt"), + Path("/data/project_b/results/output.txt"), + ] + assert io_manager.get_paths_common_prefix_suffix(paths) == ( + ["/", "data"], + ["results", "output.txt"], + [".txt"], + ) # Test case 7: Paths with only the root as common prefix and different suffix - paths = [Path('/project_a/output.txt'), Path('/project_b/output.txt')] - assert io_manager.get_paths_common_prefix_suffix(paths) == (['/'], ['output.txt'], ['.txt']) + paths = [Path("/project_a/output.txt"), Path("/project_b/output.txt")] + assert io_manager.get_paths_common_prefix_suffix(paths) == ( + ["/"], + ["output.txt"], + [".txt"], + ) # Test case 8: Paths with only the root as common prefix and different suffix - paths = [Path('/project_a/output.txt'), Path('/project_a/output.tsv')] - assert io_manager.get_paths_common_prefix_suffix(paths) == (['/', 'project_a'], [], []) + paths = [Path("/project_a/output.txt"), Path("/project_a/output.tsv")] + assert io_manager.get_paths_common_prefix_suffix(paths) == ( + ["/", "project_a"], + [], + [], + ) + def test_write_bin_info(tmp_path, bin1, bin2): # Mock input data @@ -175,7 +223,10 @@ def test_write_bin_info(tmp_path, bin1, bin2): with open(output_file, "r") as f: content = f.read() - assert "bin_id\torigin\tname\tcompleteness\tcontamination\tscore\tsize\tN50\tcontig_count" in content + assert ( + "bin_id\torigin\tname\tcompleteness\tcontamination\tscore\tsize\tN50\tcontig_count" + in content + ) assert "1\torigin1\tname1\t90\t5\t80\t1000\t500\t2" in content assert "2\torigin2\tname2\t85\t8\t75\t1200\t600\t2" in content @@ -194,13 +245,14 @@ def test_write_bin_info_add_contig(tmp_path, bin1, bin2): with open(output_file, "r") as f: content = f.read() - assert "bin_id\torigin\tname\tcompleteness\tcontamination\tscore\tsize\tN50\tcontig_count\tcontigs" in content + assert ( + "bin_id\torigin\tname\tcompleteness\tcontamination\tscore\tsize\tN50\tcontig_count\tcontigs" + in content + ) assert "1\torigin1\tname1\t90\t5\t80\t1000\t500\t2\tcontig1;contig3" in content assert "2\torigin2\tname2\t85\t8\t75\t1200\t600\t2\tcontig2;contig4" in content - - def test_write_bins_fasta(tmp_path, bin1, bin2): # Mock input data contigs_fasta = tmp_path / "contigs.fasta" @@ -226,12 +278,12 @@ def test_write_bins_fasta(tmp_path, bin1, bin2): with open(outdir / "bin_2.fa", "r") as bin2_file: assert bin2_file.read() == ">contig2\nTGCA\n>contig4\nCCCC\n" - + def test_check_contig_consistency_error(): # Mock input data - contigs_from_assembly = ['contig1', 'contig2', 'contig3'] - contigs_from_bins = ['contig2', 'contig3', 'contig4'] + contigs_from_assembly = ["contig1", "contig2", "contig3"] + contigs_from_bins = ["contig2", "contig3", "contig4"] assembly_file = "assembly.fasta" elsewhere_file = "external.fasta" @@ -241,16 +293,18 @@ def test_check_contig_consistency_error(): contigs_from_assembly, contigs_from_bins, assembly_file, elsewhere_file ) + def test_check_contig_consistency_no_error(): # Mock input data - contigs_from_assembly = ['contig1', 'contig2', 'contig3', 'contig4'] - contigs_from_bins = ['contig1', 'contig2', 'contig3'] + contigs_from_assembly = ["contig1", "contig2", "contig3", "contig4"] + contigs_from_bins = ["contig1", "contig2", "contig3"] assembly_file = "assembly.fasta" elsewhere_file = "external.fasta" io_manager.check_contig_consistency( - contigs_from_assembly, contigs_from_bins, assembly_file, elsewhere_file - ) + contigs_from_assembly, contigs_from_bins, assembly_file, elsewhere_file + ) + @pytest.fixture def temp_files(tmp_path): @@ -261,6 +315,7 @@ def temp_files(tmp_path): diamond_result_file.touch() yield str(faa_file), str(diamond_result_file) + def test_check_resume_file_exists(temp_files, caplog): # Test when both files exist faa_file, diamond_result_file = temp_files @@ -268,6 +323,7 @@ def test_check_resume_file_exists(temp_files, caplog): assert "Protein file" not in caplog.text assert "Diamond result file" not in caplog.text + def test_check_resume_file_missing_faa(temp_files, caplog): # Test when faa_file is missing _, diamond_result_file = temp_files @@ -276,23 +332,25 @@ def test_check_resume_file_missing_faa(temp_files, caplog): assert "Protein file" in caplog.text assert "Diamond result file" not in caplog.text + def test_check_resume_file_missing_diamond(temp_files, caplog): # Test when diamond_result_file is missing faa_file, _ = temp_files with pytest.raises(FileNotFoundError): - io_manager.check_resume_file(Path(faa_file), Path("nonexistent_diamond_result.txt")) + io_manager.check_resume_file( + Path(faa_file), Path("nonexistent_diamond_result.txt") + ) assert "Protein file" not in caplog.text assert "Diamond result file" in caplog.text -@patch('binette.io_manager.write_bin_info') -def test_write_original_bin_metrics(mock_write_bin_info, bin1,bin2, tmp_path): +@patch("binette.io_manager.write_bin_info") +def test_write_original_bin_metrics(mock_write_bin_info, bin1, bin2, tmp_path): # Test that `write_original_bin_metrics` correctly writes bin metrics to files - temp_directory = tmp_path / "test_output" + temp_directory = tmp_path / "test_output" - mock_bins = {"set1":{bin1}, - "set2":{bin2}} + mock_bins = {"set1": {bin1}, "set2": {bin2}} # Call the function with mock data io_manager.write_original_bin_metrics(mock_bins, temp_directory) @@ -302,14 +360,18 @@ def test_write_original_bin_metrics(mock_write_bin_info, bin1,bin2, tmp_path): # Check that the correct files are created expected_files = [ temp_directory / "input_bins_1.set1.tsv", - temp_directory / "input_bins_2.set2.tsv" + temp_directory / "input_bins_2.set2.tsv", ] - assert temp_directory.exists(), f"Expected temp_directory {temp_directory} was not created." + assert ( + temp_directory.exists() + ), f"Expected temp_directory {temp_directory} was not created." # Check if `write_bin_info` was called correctly - assert mock_write_bin_info.call_count == 2, "write_bin_info should be called once for each bin set." + assert ( + mock_write_bin_info.call_count == 2 + ), "write_bin_info should be called once for each bin set." # Verify the specific calls to `write_bin_info` - mock_write_bin_info.assert_any_call(mock_bins['set1'], expected_files[0]) - mock_write_bin_info.assert_any_call(mock_bins['set2'], expected_files[1]) \ No newline at end of file + mock_write_bin_info.assert_any_call(mock_bins["set1"], expected_files[0]) + mock_write_bin_info.assert_any_call(mock_bins["set2"], expected_files[1]) diff --git a/tests/main_binette_test.py b/tests/main_binette_test.py index 6c7b2cd..8de8238 100644 --- a/tests/main_binette_test.py +++ b/tests/main_binette_test.py @@ -1,7 +1,16 @@ - import pytest import logging -from binette.main import log_selected_bin_info, select_bins_and_write_them, manage_protein_alignement, parse_input_files, parse_arguments, init_logging, main, UniqueStore, is_valid_file +from binette.main import ( + log_selected_bin_info, + select_bins_and_write_them, + manage_protein_alignement, + parse_input_files, + parse_arguments, + init_logging, + main, + UniqueStore, + is_valid_file, +) from binette.bin_manager import Bin from binette import diamond, contig_manager, cds import os @@ -13,6 +22,7 @@ from argparse import ArgumentParser from pathlib import Path + @pytest.fixture def test_environment(tmp_path: Path): """ @@ -28,6 +38,7 @@ def test_environment(tmp_path: Path): return folder1, folder2, contigs_file + @pytest.fixture def bins(): b1 = Bin(contigs={"contig1"}, origin="set1", name="bin1") @@ -40,11 +51,11 @@ def bins(): return [b1, b2, b3] + def test_log_selected_bin_info(caplog, bins): caplog.set_level(logging.INFO) - hq_min_completeness = 85 hq_max_conta = 15 @@ -52,8 +63,7 @@ def test_log_selected_bin_info(caplog, bins): log_selected_bin_info(bins, hq_min_completeness, hq_max_conta) # Check if the logs contain expected messages - expected_logs ="2/3 selected bins have a high quality (completeness >= 85 and contamination <= 15)." - + expected_logs = "2/3 selected bins have a high quality (completeness >= 85 and contamination <= 15)." assert expected_logs in caplog.text @@ -64,26 +74,33 @@ def test_select_bins_and_write_them(tmp_path, tmpdir, bins): contigs_fasta = os.path.join(str(outdir), "contigs.fasta") final_bin_report = os.path.join(str(outdir), "final_bin_report.tsv") - index_to_contig={"contig1":"contig1", "contig2": "contig2", "contig3":"contig3"} + index_to_contig = {"contig1": "contig1", "contig2": "contig2", "contig3": "contig3"} contigs_fasta = tmp_path / "contigs.fasta" contigs_fasta_content = ( - ">contig1\nACGT\n>contig2\nTGCA\n>contig3\nAAAA\n>contig4\nCCCC\n" + ">contig1\nACGT\n>contig2\nTGCA\n>contig3\nAAAA\n>contig4\nCCCC\n" ) contigs_fasta.write_text(contigs_fasta_content) b1, b2, b3 = bins - # Run the function with test data selected_bins = select_bins_and_write_them( - set(bins), contigs_fasta, Path(final_bin_report), min_completeness=60, index_to_contig=index_to_contig, outdir=outdir, debug=True + set(bins), + contigs_fasta, + Path(final_bin_report), + min_completeness=60, + index_to_contig=index_to_contig, + outdir=outdir, + debug=True, ) # Assertions to check the function output or file existence assert isinstance(selected_bins, list) assert os.path.isfile(final_bin_report) - assert selected_bins == bins[:2] # The third bin is overlapping with the second one and has a worse score so it is not selected. + assert ( + selected_bins == bins[:2] + ) # The third bin is overlapping with the second one and has a worse score so it is not selected. with open(outdir / f"final_bins/bin_{b1.id}.fa", "r") as bin1_file: assert bin1_file.read() == ">contig1\nACGT\n" @@ -94,27 +111,25 @@ def test_select_bins_and_write_them(tmp_path, tmpdir, bins): assert not os.path.isfile(outdir / f"final_bins/bin_{b3.id}.fa") - def test_manage_protein_alignement_resume(tmp_path): # Create temporary directories and files for testing faa_file = tmp_path / "proteins.faa" faa_file_content = ( - ">contig1_1\nMCGT\n>contig2_1\nTGCA\n>contig2_2\nAAAA\n>contig3_1\nCCCC\n" + ">contig1_1\nMCGT\n>contig2_1\nTGCA\n>contig2_2\nAAAA\n>contig3_1\nCCCC\n" ) - contig_to_length={"contig1":40, "contig2":80, "contig3":20} + contig_to_length = {"contig1": 40, "contig2": 80, "contig3": 20} faa_file.write_text(faa_file_content) contig_to_kegg_id = { "contig1": Counter({"K12345": 1, "K67890": 1}), - "contig2": Counter({"K23456": 1}) + "contig2": Counter({"K23456": 1}), } - with patch("binette.diamond.get_contig_to_kegg_id", return_value=contig_to_kegg_id): - + # Call the function # Run the function with test data @@ -128,7 +143,7 @@ def test_manage_protein_alignement_resume(tmp_path): threads=1, use_existing_protein_file=True, resume_diamond=True, - low_mem=False + low_mem=False, ) # Assertions to check the function output or file existence @@ -141,27 +156,24 @@ def test_manage_protein_alignement_not_resume(tmpdir, tmp_path): # Create temporary directories and files for testing faa_file = tmp_path / "proteins.faa" - faa_file_content = ( - ">contig1_1\nMLKPACGT\n>contig2_1\nMMMKPTGCA\n>contig2_2\nMMMAAAA\n>contig3_1\nMLPALP\n" - ) + faa_file_content = ">contig1_1\nMLKPACGT\n>contig2_1\nMMMKPTGCA\n>contig2_2\nMMMAAAA\n>contig3_1\nMLPALP\n" - contig_to_length={"contig1":40, "contig2":80, "contig3":20} + contig_to_length = {"contig1": 40, "contig2": 80, "contig3": 20} faa_file.write_text(faa_file_content) - contigs_fasta = os.path.join(str(tmpdir), "contigs.fasta") diamond_result_file = os.path.join(str(tmpdir), "diamond_results.tsv") contig_to_kegg_id = { "contig1": Counter({"K12345": 1, "K67890": 1}), - "contig2": Counter({"K23456": 1}) + "contig2": Counter({"K23456": 1}), } + with patch( + "binette.diamond.get_contig_to_kegg_id", return_value=contig_to_kegg_id + ), patch("binette.diamond.run", return_value=None): - with patch("binette.diamond.get_contig_to_kegg_id", return_value=contig_to_kegg_id), \ - patch("binette.diamond.run", return_value=None): - # Call the function contig_to_kegg_counter, contig_to_genes = manage_protein_alignement( @@ -174,7 +186,7 @@ def test_manage_protein_alignement_not_resume(tmpdir, tmp_path): threads=1, use_existing_protein_file=True, resume_diamond=True, - low_mem=False + low_mem=False, ) # Assertions to check the function output or file existence @@ -191,14 +203,13 @@ def test_parse_input_files_with_contig2bin_tables(tmp_path): bin_set2.write_text("contig3\tbin2A\ncontig4\ttbin2B\n") fasta_file = tmp_path / "assembly.fasta" - fasta_file_content = ( - ">contig1\nACGT\n>contig2\nTGCA\n>contig3\nAAAA\n>contig4\nCCCC\n>contig5\nCGTCGCT\n" - ) + fasta_file_content = ">contig1\nACGT\n>contig2\nTGCA\n>contig3\nAAAA\n>contig4\nCCCC\n>contig5\nCGTCGCT\n" fasta_file.write_text(fasta_file_content) # Call the function and capture the return values - bin_set_name_to_bins, original_bins, contigs_in_bins, contig_to_length = parse_input_files(None, [bin_set1, bin_set2], fasta_file) - + bin_set_name_to_bins, original_bins, contigs_in_bins, contig_to_length = ( + parse_input_files(None, [bin_set1, bin_set2], fasta_file) + ) # # Perform assertions on the returned values assert isinstance(bin_set_name_to_bins, dict) @@ -206,20 +217,18 @@ def test_parse_input_files_with_contig2bin_tables(tmp_path): assert isinstance(contigs_in_bins, set) assert isinstance(contig_to_length, dict) - - assert set(bin_set_name_to_bins) == {'bin_set1', "bin_set2"} + assert set(bin_set_name_to_bins) == {"bin_set1", "bin_set2"} assert len(original_bins) == 4 - assert contigs_in_bins == {"contig1","contig2", "contig3","contig4"} + assert contigs_in_bins == {"contig1", "contig2", "contig3", "contig4"} assert len(contig_to_length) == 4 + def test_parse_input_files_with_contig2bin_tables_with_unknown_contig(tmp_path): bin_set3 = tmp_path / "bin_set3.tsv" bin_set3.write_text("contig3\tbin3A\ncontig44\ttbin3B\n") fasta_file = tmp_path / "assembly.fasta" - fasta_file_content = ( - ">contig1\nACGT\n>contig2\nTGCA\n>contig3\nAAAA\n>contig4\nCCCC\n>contig5\nCGTCGCT\n" - ) + fasta_file_content = ">contig1\nACGT\n>contig2\nTGCA\n>contig3\nAAAA\n>contig4\nCCCC\n>contig5\nCGTCGCT\n" fasta_file.write_text(fasta_file_content) with pytest.raises(ValueError): @@ -235,13 +244,13 @@ def test_parse_input_files_bin_dirs(create_temp_bin_directories, tmp_path): # Create temporary directories and files for testing fasta_file = tmp_path / "assembly.fasta" - fasta_file_content = ( - ">contig1\nACGT\n>contig2\nTGCA\n>contig3\nAAAA\n>contig4\nCCCC\n>contig5\nCGTCGCT\n" - ) + fasta_file_content = ">contig1\nACGT\n>contig2\nTGCA\n>contig3\nAAAA\n>contig4\nCCCC\n>contig5\nCGTCGCT\n" fasta_file.write_text(fasta_file_content) # Call the function and capture the return values - bin_set_name_to_bins, original_bins, contigs_in_bins, contig_to_length = parse_input_files(bin_dirs, contig2bin_tables, fasta_file) + bin_set_name_to_bins, original_bins, contigs_in_bins, contig_to_length = ( + parse_input_files(bin_dirs, contig2bin_tables, fasta_file) + ) # # Perform assertions on the returned values assert isinstance(bin_set_name_to_bins, dict) @@ -249,26 +258,32 @@ def test_parse_input_files_bin_dirs(create_temp_bin_directories, tmp_path): assert isinstance(contigs_in_bins, set) assert isinstance(contig_to_length, dict) - - assert set(bin_set_name_to_bins) == {'set1', 'set2'} + assert set(bin_set_name_to_bins) == {"set1", "set2"} assert len(original_bins) == 3 - assert contigs_in_bins == {"contig1","contig2", "contig3","contig4","contig5",} + assert contigs_in_bins == { + "contig1", + "contig2", + "contig3", + "contig4", + "contig5", + } assert len(contig_to_length) == 5 def test_argument_used_once(): - # Test UniqueStore class - parser = ArgumentParser(description='Test parser') - parser.add_argument('--example', action=UniqueStore, help='Example argument') - args = parser.parse_args(['--example', 'value']) - assert args.example == 'value' + # Test UniqueStore class + parser = ArgumentParser(description="Test parser") + parser.add_argument("--example", action=UniqueStore, help="Example argument") + args = parser.parse_args(["--example", "value"]) + assert args.example == "value" + def test_argument_used_multiple_times(): - # Test UniqueStore class - parser = ArgumentParser(description='Test parser') - parser.add_argument('--example', action=UniqueStore, help='Example argument') + # Test UniqueStore class + parser = ArgumentParser(description="Test parser") + parser.add_argument("--example", action=UniqueStore, help="Example argument") with pytest.raises(SystemExit): - parser.parse_args(['--example', 'value', '--example', 'value2']) + parser.parse_args(["--example", "value", "--example", "value2"]) def test_parse_arguments_required_arguments(test_environment): @@ -289,23 +304,37 @@ def test_parse_arguments_required_arguments(test_environment): def test_parse_arguments_optional_arguments(test_environment): # Test when required and optional arguments are provided - + # Create temporary directories and files - folder1, folder2, contigs_file = test_environment + folder1, folder2, contigs_file = test_environment # Parse arguments with existing files and directories - args = parse_arguments(["-d", str(folder1), str(folder2), "-c", str(contigs_file), "--threads", "4", "--outdir", "output"]) + args = parse_arguments( + [ + "-d", + str(folder1), + str(folder2), + "-c", + str(contigs_file), + "--threads", + "4", + "--outdir", + "output", + ] + ) assert args.bin_dirs == [folder1, folder2] assert args.contigs == contigs_file assert args.threads == 4 assert args.outdir == Path("output") + def test_parse_arguments_invalid_arguments(): # Test when invalid arguments are provided with pytest.raises(SystemExit): # In this case, required arguments are missing parse_arguments(["-t", "4"]) + def test_parse_arguments_help(): # Test the help message with pytest.raises(SystemExit) as pytest_wrapped_e: @@ -315,7 +344,7 @@ def test_parse_arguments_help(): def test_init_logging_command_line(caplog): - + caplog.set_level(logging.INFO) init_logging(verbose=True, debug=False) @@ -341,89 +370,127 @@ def test_manage_protein_alignment_no_resume(tmp_path): low_mem = False # Mock the necessary functions - with patch('binette.contig_manager.parse_fasta_file') as mock_parse_fasta_file, \ - patch('binette.cds.predict') as mock_predict, \ - patch('binette.diamond.get_checkm2_db') as mock_get_checkm2_db, \ - patch('binette.diamond.run') as mock_diamond_run, \ - patch('binette.diamond.get_contig_to_kegg_id') as mock_diamond_get_contig_to_kegg_id: - + with patch( + "binette.contig_manager.parse_fasta_file" + ) as mock_parse_fasta_file, patch("binette.cds.predict") as mock_predict, patch( + "binette.diamond.get_checkm2_db" + ) as mock_get_checkm2_db, patch( + "binette.diamond.run" + ) as mock_diamond_run, patch( + "binette.diamond.get_contig_to_kegg_id" + ) as mock_diamond_get_contig_to_kegg_id: + # Set the return value of the mocked functions mock_parse_fasta_file.return_value = [MagicMock(name="contig1")] mock_predict.return_value = {"contig1": ["gene1"]} - + # Call the function contig_to_kegg_counter, contig_to_genes = manage_protein_alignement( - faa_file, contigs_fasta, contig_to_length, contigs_in_bins, - diamond_result_file, checkm2_db, threads, resume, resume, low_mem + faa_file, + contigs_fasta, + contig_to_length, + contigs_in_bins, + diamond_result_file, + checkm2_db, + threads, + resume, + resume, + low_mem, ) - + # Assertions to check if functions were called mock_parse_fasta_file.assert_called_once_with(contigs_fasta.as_posix()) mock_predict.assert_called_once() mock_diamond_get_contig_to_kegg_id.assert_called_once() mock_diamond_run.assert_called_once_with( - faa_file.as_posix(), diamond_result_file.as_posix(), checkm2_db.as_posix(), f"{os.path.splitext(diamond_result_file.as_posix())[0]}.log", threads, low_mem=low_mem + faa_file.as_posix(), + diamond_result_file.as_posix(), + checkm2_db.as_posix(), + f"{os.path.splitext(diamond_result_file.as_posix())[0]}.log", + threads, + low_mem=low_mem, ) + def test_main_resume_when_not_possible(monkeypatch, test_environment): # Define or mock the necessary inputs/arguments folder1, folder2, contigs_file = test_environment # Mock sys.argv to use test_args - test_args = ["-d", str(folder1), str(folder2), "-c", str(contigs_file), + test_args = [ + "-d", + str(folder1), + str(folder2), + "-c", + str(contigs_file), # ... more arguments as required ... "--debug", - "--resume" + "--resume", ] - monkeypatch.setattr(sys, 'argv', ['your_script.py'] + test_args) + monkeypatch.setattr(sys, "argv", ["your_script.py"] + test_args) # Call the main function with pytest.raises(FileNotFoundError): main() + def test_main(monkeypatch, test_environment): # Define or mock the necessary inputs/arguments folder1, folder2, contigs_file = test_environment # Mock sys.argv to use test_args test_args = [ - "-d", str(folder1), str(folder2), "-c", str(contigs_file), + "-d", + str(folder1), + str(folder2), + "-c", + str(contigs_file), # ... more arguments as required ... - "--debug" + "--debug", ] - monkeypatch.setattr(sys, 'argv', ['your_script.py'] + test_args) + monkeypatch.setattr(sys, "argv", ["your_script.py"] + test_args) # Mock the necessary functions - with patch('binette.main.parse_input_files') as mock_parse_input_files, \ - patch('binette.main.manage_protein_alignement') as mock_manage_protein_alignement, \ - patch('binette.contig_manager.apply_contig_index') as mock_apply_contig_index, \ - patch('binette.bin_manager.rename_bin_contigs') as mock_rename_bin_contigs, \ - patch('binette.bin_manager.create_intermediate_bins') as mock_create_intermediate_bins, \ - patch('binette.bin_quality.add_bin_metrics') as mock_add_bin_metrics, \ - patch('binette.main.log_selected_bin_info') as mock_log_selected_bin_info, \ - patch('binette.contig_manager.make_contig_index') as mock_make_contig_index, \ - patch('binette.io_manager.write_original_bin_metrics') as mock_write_original_bin_metrics, \ - patch('binette.main.select_bins_and_write_them') as mock_select_bins_and_write_them: - + with patch("binette.main.parse_input_files") as mock_parse_input_files, patch( + "binette.main.manage_protein_alignement" + ) as mock_manage_protein_alignement, patch( + "binette.contig_manager.apply_contig_index" + ) as mock_apply_contig_index, patch( + "binette.bin_manager.rename_bin_contigs" + ) as mock_rename_bin_contigs, patch( + "binette.bin_manager.create_intermediate_bins" + ) as mock_create_intermediate_bins, patch( + "binette.bin_quality.add_bin_metrics" + ) as mock_add_bin_metrics, patch( + "binette.main.log_selected_bin_info" + ) as mock_log_selected_bin_info, patch( + "binette.contig_manager.make_contig_index" + ) as mock_make_contig_index, patch( + "binette.io_manager.write_original_bin_metrics" + ) as mock_write_original_bin_metrics, patch( + "binette.main.select_bins_and_write_them" + ) as mock_select_bins_and_write_them: + # Set return values for mocked functions if needed mock_parse_input_files.return_value = (None, None, None, None) - mock_manage_protein_alignement.return_value = ({"contig1": 1}, {"contig1": ["gene1"]}) + mock_manage_protein_alignement.return_value = ( + {"contig1": 1}, + {"contig1": ["gene1"]}, + ) mock_make_contig_index.return_value = ({}, {}) mock_apply_contig_index.return_value = MagicMock() mock_rename_bin_contigs.return_value = MagicMock() mock_create_intermediate_bins.return_value = MagicMock() mock_add_bin_metrics.return_value = MagicMock() mock_log_selected_bin_info.return_value = MagicMock() - - - main() + main() # Add assertions to ensure the mocks were called as expected mock_parse_input_files.assert_called_once() mock_manage_protein_alignement.assert_called_once() mock_rename_bin_contigs.assert_called_once() mock_create_intermediate_bins.assert_called_once() - + mock_log_selected_bin_info.assert_called_once() mock_select_bins_and_write_them.assert_called_once() mock_write_original_bin_metrics.assert_called_once() @@ -444,6 +511,7 @@ def test_is_valid_file_existing_file(tmp_path: Path): result = is_valid_file(parser, str(test_file)) assert result == test_file + def test_is_valid_file_non_existing_file(): """Test is_valid_file with a file that does not exist.""" parser = ArgumentParser() @@ -451,4 +519,4 @@ def test_is_valid_file_non_existing_file(): # Expect the function to call parser.error, which will raise a SystemExit exception with pytest.raises(SystemExit): - is_valid_file(parser, non_existing_file) \ No newline at end of file + is_valid_file(parser, non_existing_file)