From 05c8d527832202cf69e998bf0781c5c8f9413d6f Mon Sep 17 00:00:00 2001
From: Michal Puncochar
Date: Mon, 9 Oct 2023 18:09:48 +0200
Subject: [PATCH] Optimize print clades only

---
 metaphlan/strainphlan.py               | 125 +++++++++++++-------
 metaphlan/utils/database_controller.py |  10 +-
 metaphlan/utils/sample2markers.py      |   2 +-
 3 files changed, 76 insertions(+), 61 deletions(-)

diff --git a/metaphlan/strainphlan.py b/metaphlan/strainphlan.py
index f347c82..d7009c5 100755
--- a/metaphlan/strainphlan.py
+++ b/metaphlan/strainphlan.py
@@ -10,12 +10,12 @@
 import argparse as ap
-import collections
 import io
 import os
 import re
 import tempfile
 import time
+from collections import OrderedDict
 from shutil import copyfile, rmtree
 from typing import Iterable

@@ -80,40 +80,40 @@ def filter_markers_matrix(self, markers_matrix, messages=True):
         """Filters the primary samples, references and markers based on the user defined thresholds

         Args:
-            markers_matrix (list[dict]): The list with the sample-to-markers information
+            markers_matrix (pd.DataFrame): The presence-absence data frame with index as samples and columns as markers
             messages (bool): Whether to be verbose and halt execution when less than 4 samples are left. Defaults to True.
         """
-        mm = pd.DataFrame.from_records(markers_matrix, index='sample')  # df with index samples and columns markers

         # Step I: samples with not enough markers are treated as secondary
         # here the percentage is calculated from the number of markers in at least one sample, it can be less than
         # the total number of available markers in the database for the clade
-        n_samples_0, n_markers_0 = mm.shape
+        n_samples_0, n_markers_0 = markers_matrix.shape
         min_markers = max(self.sample_with_n_markers, n_markers_0 * self.sample_with_n_markers_perc / 100)
-        mm_primary = mm.loc[mm.sum(axis=1) >= min_markers]
+        mm_primary = markers_matrix.loc[markers_matrix.sum(axis=1) >= min_markers]

         # Step II: filter markers in not enough primary samples
         n_samples_primary, _ = mm_primary.shape
         min_samples = n_samples_primary * self.marker_in_n_samples_perc / 100
-        mm = mm.loc[:, mm_primary.sum(axis=0) >= min_samples]
+        markers_matrix = markers_matrix.loc[:, mm_primary.sum(axis=0) >= min_samples]

         # Step III: filter samples with not enough markers
         # here the percentage is calculated from the remaining markers
-        _, n_markers_2 = mm.shape
+        _, n_markers_2 = markers_matrix.shape
         min_markers = max(self.sample_with_n_markers_after_filt, n_markers_2 * self.sample_with_n_markers_after_filt_perc / 100)
-        mm = mm.loc[mm.sum(axis=1) >= min_markers]
-        n_samples_3 = len(mm)
+        markers_matrix = markers_matrix.loc[markers_matrix.sum(axis=1) >= min_markers]
+
+        n_samples_3 = len(markers_matrix)
         if n_samples_3 < 4:
             if messages:
                 error(f"Phylogeny can not be inferred. Less than 4 samples remained after filtering.\n"
                      f"{n_samples_3} / {n_samples_0} samples ({n_samples_primary} primary) "
                      f"and {n_markers_2} / {n_markers_0 } markers remained.", exit=True)
-            return
+            return markers_matrix

-        self.cleaned_markers_matrix = mm
+        return markers_matrix


@@ -131,16 +131,19 @@ def copy_filtered_references(self, markers_tmp_dir, filtered_samples):
             copyfile(reference_marker, os.path.join(markers_tmp_dir, f"{reference_name}.fna.bz2"))


-    def matrix_markers_to_fasta(self):
+    def matrix_markers_to_fasta(self, markers_matrix):
         """For each sample, writes the FASTA files with the sequences of the filtered markers

+        Args:
+            markers_matrix (pd.DataFrame): the filtered markers matrix with samples as index and markers as columns
+
         Returns:
             str: the temporary folder where the FASTA files were written
         """
         markers_tmp_dir = os.path.join(self.tmp_dir, "{}.StrainPhlAn4".format(self.clade))
         create_folder(markers_tmp_dir)
-        filtered_samples = set(self.cleaned_markers_matrix.index)
-        filtered_markers = set(self.cleaned_markers_matrix.columns)
+        filtered_samples = set(markers_matrix.index)
+        filtered_markers = set(markers_matrix.columns)
         self.copy_filtered_references(markers_tmp_dir, filtered_samples)
         execute_pool(((Strainphlan.sample_markers_to_fasta, sample, filtered_samples, filtered_markers, self.trim_sequences, markers_tmp_dir) for sample in self.samples), self.nprocs)
@@ -176,11 +179,13 @@ def get_markers_from_references(self):
         """
         if not self.clade_markers_file:
             self.database_controller.extract_markers([self.clade], self.tmp_dir)
-            self.clade_markers_file = os.path.join(self.tmp_dir, "{}.fna".format(self.clade))
+            clade_markers_file = os.path.join(self.tmp_dir, "{}.fna".format(self.clade))
         elif self.clade_markers_file.endswith(".bz2"):
-            self.clade_markers_file = decompress_bz2(self.clade_markers_file, self.tmp_dir)
+            clade_markers_file = decompress_bz2(self.clade_markers_file, self.tmp_dir)
+        else:
+            clade_markers_file = self.clade_markers_file

-        return execute_pool(((Strainphlan.process_reference, reference, self.tmp_dir, self.clade_markers_file,
+        return execute_pool(((Strainphlan.process_reference, reference, self.tmp_dir, clade_markers_file,
                               self.clade_markers_names, self.trim_sequences) for reference in self.references), self.nprocs)

@@ -319,9 +324,9 @@ def calculate_polymorphic_rates(self):
         df.to_csv(os.path.join(self.output_dir, f"{self.clade}.polymorphic"), sep='\t', index=False)


-    def write_info(self):
+    def write_info(self, markers_matrix):
         """Writes the information file for the execution"""
-        filtered_names = [self.sample_path_to_name(sample) for sample in self.cleaned_markers_matrix.index]
+        filtered_names = [self.sample_path_to_name(sample) for sample in markers_matrix.index]
         with open(os.path.join(self.output_dir, "{}.info".format(self.clade)), 'w') as info_file:
             info_file.write("Clade: {}\n".format(self.clade))
             info_file.write(
@@ -337,10 +342,9 @@ def write_info(self):
             info_file.write(f"\tMinimum number of markers to keep a sample after filtering: {self.sample_with_n_markers_after_filt}\n")
             info_file.write(f"\tMinimum percentage of markers to keep a sample after filtering: {self.sample_with_n_markers_after_filt_perc}\n")
             info_file.write(f"\tMinimum percentage of samples to keep a marker: {self.marker_in_n_samples_perc}\n")
-            info_file.write("Number of markers selected after filtering: {}\n".format(
-                len(self.cleaned_markers_matrix.columns)))
+            info_file.write("Number of markers selected after filtering: {}\n".format(len(markers_matrix.columns)))
info_file.write("Number of samples after filtering: {}\n".format(len( - [sample for sample in self.samples if sample in self.cleaned_markers_matrix.index]))) + [sample for sample in self.samples if sample in markers_matrix.index]))) info_file.write("Number of references after filtering: {}\n".format(len([reference for reference in self.references if self.sample_path_to_name(reference) in filtered_names]))) info_file.write( "PhyloPhlan phylogenetic precision mode: {}\n".format(self.phylophlan_mode)) @@ -354,22 +358,30 @@ def detect_clades(self): Returns: dict: dictionary containing the number of samples a clade can be reconstructed from """ - markers2species = self.database_controller.get_markers2species() + markers2clade = self.database_controller.get_markers2clade() + clade2markers = self.database_controller.get_clade2markers() + sample2markers = {} + clades_to_check = set() + info('Processing samples...') + consensus_markers = execute_pool([(ConsensusMarkers.from_file, sample_path) for sample_path in self.samples], nprocs=self.nprocs) + for sample_path, cm in zip(self.samples, consensus_markers): + markers = [marker.name for marker in cm.consensus_markers + if (marker.name in markers2clade and marker.breadth >= self.breadth_thres)] + sample2markers[sample_path] = markers + clades_to_check.update((markers2clade[m] for m in markers)) + + info('Constructing the big marker matrix') + markers_matrix_big = [pd.Series({m: 1 for m in markers}, name=sample) for sample, markers in sample2markers.items()] + markers_matrix_big = pd.concat(markers_matrix_big, axis=1).fillna(0) + + info(f'Checking {len(clades_to_check)} species') species2samples = {} - species_to_check = set() - info('Detecting clades...') - for sample_path in self.samples: - sample = ConsensusMarkers.from_file(sample_path) - species_to_check.update((markers2species[marker.name] for marker in sample.consensus_markers if ( - marker.name in markers2species and marker.breadth >= self.breadth_thres))) - info(f' Will check {len(species_to_check)} species') - for species in species_to_check: - self.cleaned_markers_matrix = pd.DataFrame() - self.clade = species - self.clade_markers_names = self.database_controller.get_markers_for_clade(species) - self.filter_markers_samples(print_clades=True) - if len(self.cleaned_markers_matrix) >= 4: - species2samples[species] = len(self.cleaned_markers_matrix) + for clade in clades_to_check: + markers_matrix = markers_matrix_big.reindex(clade2markers[clade]).fillna(0) + markers_matrix_filtered = self.filter_markers_matrix(markers_matrix, messages=False) + n_samples, _ = markers_matrix_filtered.shape + if n_samples >= 4: + species2samples[clade] = n_samples info('Done.') return species2samples @@ -378,8 +390,7 @@ def print_clades(self): """Prints the clades detected in the reconstructed markers""" species2samples = self.detect_clades() info('Detected clades: ') - sorted_species2samples = collections.OrderedDict(sorted(species2samples.items(), key=lambda kv: kv[1], - reverse=True)) + sorted_species2samples = OrderedDict(sorted(species2samples.items(), key=lambda kv: kv[1], reverse=True)) with open(os.path.join(self.output_dir, 'print_clades_only.tsv'), 'w') as wf: wf.write('Clade\tNumber_of_samples\n') for species in sorted_species2samples: @@ -432,25 +443,23 @@ def interactive_clade_selection(self): self.clade = selected_clade - def filter_markers_samples(self, print_clades=False): + def filter_markers_samples(self): """Retrieves the filtered markers matrix with the filtered samples and references - Args: 
-            print_clades (bool, optional): Whether it was run in the print_clade_only mode. Defaults to False.
         """
-        if not print_clades:
-            info("Getting markers from samples...")
+        info("Getting markers from samples...")
         markers_matrix = self.get_markers_matrix_from_samples()
-        if not print_clades:
-            info("Done.")
-            if len(self.references) > 0:
-                info("Getting markers from references...")
-                markers_matrix += self.get_markers_from_references()
-                info("Done.")
-            info("Removing bad markers / samples...")
-        self.filter_markers_matrix(markers_matrix, messages=not print_clades)
-        if not print_clades:
+        info("Done.")
+        if len(self.references) > 0:
+            info("Getting markers from references...")
+            markers_matrix += self.get_markers_from_references()
             info("Done.")
+        info("Removing markers / samples...")
+        markers_matrix = pd.DataFrame.from_records(markers_matrix, index='sample')  # df with index samples and columns markers
+        markers_matrix_filtered = self.filter_markers_matrix(markers_matrix, messages=True)
+        info("Done.")
+
+        return markers_matrix_filtered


     def run_strainphlan(self):
@@ -466,20 +475,19 @@
         self.tmp_dir = tempfile.mkdtemp(dir=self.tmp_dir)
         info("Done.")
         info("Filtering markers and samples...")
-        self.filter_markers_samples()
+        markers_matrix = self.filter_markers_samples()
         info("Done.")
         info("Writing samples as markers' FASTA files...")
-        samples_as_markers_dir = self.matrix_markers_to_fasta()
+        samples_as_markers_dir = self.matrix_markers_to_fasta(markers_matrix)
         info("Done.")
         info("Calculating polymorphic rates...")
         self.calculate_polymorphic_rates()
         info("Done.")
         info("Computing phylogeny...")
-        self.phylophlan_controller.compute_phylogeny(samples_as_markers_dir, len(self.cleaned_markers_matrix),
-                                                     self.tmp_dir)
+        self.phylophlan_controller.compute_phylogeny(samples_as_markers_dir, len(markers_matrix), self.tmp_dir)
         info("Done.")
         info("Writing information file...")
-        self.write_info()
+        self.write_info(markers_matrix)
         info("Done.")
         if not self.debug:
             info("Removing temporary files...")
@@ -489,7 +497,6 @@ def __init__(self, args):
         self.clade_markers_names = None
-        self.cleaned_markers_matrix = None
         self.database_controller = MetaphlanDatabaseController(args.database)
         self.clade_markers_file = args.clade_markers
         self.samples = args.samples
diff --git a/metaphlan/utils/database_controller.py b/metaphlan/utils/database_controller.py
index d6cfed4..b739de6 100644
--- a/metaphlan/utils/database_controller.py
+++ b/metaphlan/utils/database_controller.py
@@ -6,6 +6,8 @@
 import os
 import pickle
 import bz2
+
+import pandas as pd
 from Bio import SeqIO
 try:
     from .external_exec import generate_markers_fasta
@@ -37,7 +39,7 @@ def get_database_name(self):
         return self.database.split('/')[-1][:-4]


-    def get_markers2species(self):
+    def get_markers2clade(self):
         """Retrieve information from the MetaPhlAn database

         Returns:
@@ -47,6 +49,12 @@
         return {marker_name: marker_info['clade'] for marker_name, marker_info in self.database_pkl['markers'].items()}


+    def get_clade2markers(self):
+        markers2clade = self.get_markers2clade()
+        markers2clade = pd.Series(markers2clade)
+        return markers2clade.groupby(markers2clade).groups
+
+
     def get_markers_for_clade(self, clade):
         """

diff --git a/metaphlan/utils/sample2markers.py b/metaphlan/utils/sample2markers.py
index 49580ed..49da600 100755
--- a/metaphlan/utils/sample2markers.py
+++ b/metaphlan/utils/sample2markers.py
@@ -200,7 +200,7 @@ def filter_consensuses(self, consensuses, coverages):
             list[ConsensusMarker]:
         """
         markers2ext = self.database_controller.get_markers2ext()
-        markers2clade = self.database_controller.get_markers2species()
+        markers2clade = self.database_controller.get_markers2clade()
         clade2nmarkers = Counter(markers2clade.values())
         markers_all = list(consensuses.keys())
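
A note on the new get_clade2markers() helper: it inverts the marker-to-clade mapping by grouping a pandas Series by its own values, giving a clade-to-markers lookup that detect_clades() can slice with. A minimal, self-contained sketch of the same idiom, using made-up marker and clade names rather than anything from the real database:

import pandas as pd

# Hypothetical marker-to-clade mapping, shaped like the dict get_markers2clade() returns.
markers2clade = {
    'marker_A1': 's__Clade_A',
    'marker_A2': 's__Clade_A',
    'marker_B1': 's__Clade_B',
}

s = pd.Series(markers2clade)           # index: marker names, values: clade names
clade2markers = s.groupby(s).groups    # dict-like: clade -> Index of marker names

print(list(clade2markers['s__Clade_A']))   # ['marker_A1', 'marker_A2']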
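
The core idea behind the reworked detect_clades() is to parse each sample's consensus markers once, build a single presence-absence table, and then slice that table per clade instead of re-collecting markers for every candidate species. Below is a simplified, self-contained sketch of that flow with toy sample, marker and clade names; it is not the patch code (it orients the table with samples as rows, and a single 50% coverage cut-off stands in for the real sample_with_n_markers / marker_in_n_samples thresholds):

import pandas as pd

sample2markers = {
    'sampleA': ['m1', 'm2', 'm3'],
    'sampleB': ['m2', 'm3'],
    'sampleC': ['m3', 'm4'],
}
clade2markers = {'cladeX': ['m1', 'm2', 'm3'], 'cladeY': ['m4', 'm5']}

# One row per sample, one column per reconstructed marker, 1 = marker present.
big = pd.DataFrame.from_dict(
    {s: {m: 1 for m in ms} for s, ms in sample2markers.items()}, orient='index'
).fillna(0)

for clade, markers in clade2markers.items():
    # Restrict the big table to this clade's markers; markers never seen in any
    # sample become all-zero columns.
    mm = big.reindex(columns=markers).fillna(0)
    # Toy stand-in for the real filtering: keep samples covering >= 50% of the clade markers.
    kept = mm.loc[mm.sum(axis=1) >= 0.5 * mm.shape[1]]
    print(clade, len(kept), 'samples pass')

Compared to the previous implementation, which re-read the consensus marker files and rebuilt a fresh matrix for every candidate species via filter_markers_samples(print_clades=True), this does the expensive per-sample work exactly once and reduces the per-clade cost to a cheap pandas reindex plus the threshold filter.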