Skip to content

Commit

Permalink
Optimize print clades only
Browse files Browse the repository at this point in the history
  • Loading branch information
PuncocharM committed Oct 9, 2023
1 parent 34c7444 commit 05c8d52
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 61 deletions.
125 changes: 66 additions & 59 deletions metaphlan/strainphlan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@


import argparse as ap
import collections
import io
import os
import re
import tempfile
import time
from collections import OrderedDict
from shutil import copyfile, rmtree
from typing import Iterable

Expand Down Expand Up @@ -80,40 +80,40 @@ def filter_markers_matrix(self, markers_matrix, messages=True):
"""Filters the primary samples, references and markers based on the user defined thresholds
Args:
markers_matrix (list[dict]): The list with the sample-to-markers information
markers_matrix (pd.DataFrame): The presence-absence data frame with index as samples and columns as markers
messages (bool): Whether to be verbose and halt execution when fewer than 4 samples are left.
Defaults to True.
"""
mm = pd.DataFrame.from_records(markers_matrix, index='sample') # df with index samples and columns markers

# Step I: samples with not enough markers are treated as secondary
# here the percentage is calculated from the number of markers present in at least one sample, which can be
# less than the total number of available markers in the database for the clade
n_samples_0, n_markers_0 = mm.shape
n_samples_0, n_markers_0 = markers_matrix.shape
min_markers = max(self.sample_with_n_markers, n_markers_0 * self.sample_with_n_markers_perc / 100)
mm_primary = mm.loc[mm.sum(axis=1) >= min_markers]
mm_primary = markers_matrix.loc[markers_matrix.sum(axis=1) >= min_markers]

# Step II: filter markers in not enough primary samples
n_samples_primary, _ = mm_primary.shape
min_samples = n_samples_primary * self.marker_in_n_samples_perc / 100
mm = mm.loc[:, mm_primary.sum(axis=0) >= min_samples]
markers_matrix = markers_matrix.loc[:, mm_primary.sum(axis=0) >= min_samples]

# Step III: filter samples with not enough markers
# here the percentage is calculated from the remaining markers
_, n_markers_2 = mm.shape
_, n_markers_2 = markers_matrix.shape
min_markers = max(self.sample_with_n_markers_after_filt,
n_markers_2 * self.sample_with_n_markers_after_filt_perc / 100)
mm = mm.loc[mm.sum(axis=1) >= min_markers]
n_samples_3 = len(mm)
markers_matrix = markers_matrix.loc[markers_matrix.sum(axis=1) >= min_markers]

n_samples_3 = len(markers_matrix)
if n_samples_3 < 4:
if messages:
error(f"Phylogeny can not be inferred. Less than 4 samples remained after filtering.\n"
f"{n_samples_3} / {n_samples_0} samples ({n_samples_primary} primary) "
f"and {n_markers_2} / {n_markers_0 } markers remained.",
exit=True)
return
return markers_matrix

self.cleaned_markers_matrix = mm
return markers_matrix


def copy_filtered_references(self, markers_tmp_dir, filtered_samples):
Expand All @@ -131,16 +131,19 @@ def copy_filtered_references(self, markers_tmp_dir, filtered_samples):
copyfile(reference_marker, os.path.join(markers_tmp_dir, f"{reference_name}.fna.bz2"))


def matrix_markers_to_fasta(self):
def matrix_markers_to_fasta(self, markers_matrix):
"""For each sample, writes the FASTA files with the sequences of the filtered markers
Args:
markers_matrix (pd.DataFrame): The filtered presence-absence data frame with index as samples and columns as markers
Returns:
str: the temporary folder where the FASTA files were written
"""
markers_tmp_dir = os.path.join(self.tmp_dir, "{}.StrainPhlAn4".format(self.clade))
create_folder(markers_tmp_dir)
filtered_samples = set(self.cleaned_markers_matrix.index)
filtered_markers = set(self.cleaned_markers_matrix.columns)
filtered_samples = set(markers_matrix.index)
filtered_markers = set(markers_matrix.columns)
self.copy_filtered_references(markers_tmp_dir, filtered_samples)
execute_pool(((Strainphlan.sample_markers_to_fasta, sample, filtered_samples, filtered_markers,
self.trim_sequences, markers_tmp_dir) for sample in self.samples), self.nprocs)
Expand Down Expand Up @@ -176,11 +179,13 @@ def get_markers_from_references(self):
"""
if not self.clade_markers_file:
self.database_controller.extract_markers([self.clade], self.tmp_dir)
self.clade_markers_file = os.path.join(self.tmp_dir, "{}.fna".format(self.clade))
clade_markers_file = os.path.join(self.tmp_dir, "{}.fna".format(self.clade))
elif self.clade_markers_file.endswith(".bz2"):
self.clade_markers_file = decompress_bz2(self.clade_markers_file, self.tmp_dir)
clade_markers_file = decompress_bz2(self.clade_markers_file, self.tmp_dir)
else:
clade_markers_file = self.clade_markers_file

return execute_pool(((Strainphlan.process_reference, reference, self.tmp_dir, self.clade_markers_file,
return execute_pool(((Strainphlan.process_reference, reference, self.tmp_dir, clade_markers_file,
self.clade_markers_names, self.trim_sequences)
for reference in self.references), self.nprocs)

Expand Down Expand Up @@ -319,9 +324,9 @@ def calculate_polymorphic_rates(self):
df.to_csv(os.path.join(self.output_dir, f"{self.clade}.polymorphic"), sep='\t', index=False)


def write_info(self):
def write_info(self, markers_matrix):
"""Writes the information file for the execution"""
filtered_names = [self.sample_path_to_name(sample) for sample in self.cleaned_markers_matrix.index]
filtered_names = [self.sample_path_to_name(sample) for sample in markers_matrix.index]
with open(os.path.join(self.output_dir, "{}.info".format(self.clade)), 'w') as info_file:
info_file.write("Clade: {}\n".format(self.clade))
info_file.write(
Expand All @@ -337,10 +342,9 @@ def write_info(self):
info_file.write(f"\tMinimum number of markers to keep a sample after filtering: {self.sample_with_n_markers_after_filt}\n")
info_file.write(f"\tMinimum percentage of markers to keep a sample after filtering: {self.sample_with_n_markers_after_filt_perc}\n")
info_file.write(f"\tMinimum percentage of samples to keep a marker: {self.marker_in_n_samples_perc}\n")
info_file.write("Number of markers selected after filtering: {}\n".format(
len(self.cleaned_markers_matrix.columns)))
info_file.write("Number of markers selected after filtering: {}\n".format(len(markers_matrix.columns)))
info_file.write("Number of samples after filtering: {}\n".format(len(
[sample for sample in self.samples if sample in self.cleaned_markers_matrix.index])))
[sample for sample in self.samples if sample in markers_matrix.index])))
info_file.write("Number of references after filtering: {}\n".format(len([reference for reference in self.references if self.sample_path_to_name(reference) in filtered_names])))
info_file.write(
"PhyloPhlan phylogenetic precision mode: {}\n".format(self.phylophlan_mode))
Expand All @@ -354,22 +358,30 @@ def detect_clades(self):
Returns:
dict: dictionary containing the number of samples a clade can be reconstructed from
"""
markers2species = self.database_controller.get_markers2species()
markers2clade = self.database_controller.get_markers2clade()
clade2markers = self.database_controller.get_clade2markers()
sample2markers = {}
clades_to_check = set()
info('Processing samples...')
consensus_markers = execute_pool([(ConsensusMarkers.from_file, sample_path) for sample_path in self.samples], nprocs=self.nprocs)
for sample_path, cm in zip(self.samples, consensus_markers):
markers = [marker.name for marker in cm.consensus_markers
if (marker.name in markers2clade and marker.breadth >= self.breadth_thres)]
sample2markers[sample_path] = markers
clades_to_check.update((markers2clade[m] for m in markers))

info('Constructing the big marker matrix')
markers_matrix_big = [pd.Series({m: 1 for m in markers}, name=sample) for sample, markers in sample2markers.items()]
markers_matrix_big = pd.concat(markers_matrix_big, axis=1).fillna(0)

info(f'Checking {len(clades_to_check)} species')
species2samples = {}
species_to_check = set()
info('Detecting clades...')
for sample_path in self.samples:
sample = ConsensusMarkers.from_file(sample_path)
species_to_check.update((markers2species[marker.name] for marker in sample.consensus_markers if (
marker.name in markers2species and marker.breadth >= self.breadth_thres)))
info(f' Will check {len(species_to_check)} species')
for species in species_to_check:
self.cleaned_markers_matrix = pd.DataFrame()
self.clade = species
self.clade_markers_names = self.database_controller.get_markers_for_clade(species)
self.filter_markers_samples(print_clades=True)
if len(self.cleaned_markers_matrix) >= 4:
species2samples[species] = len(self.cleaned_markers_matrix)
for clade in clades_to_check:
markers_matrix = markers_matrix_big.reindex(clade2markers[clade]).fillna(0)
markers_matrix_filtered = self.filter_markers_matrix(markers_matrix, messages=False)
n_samples, _ = markers_matrix_filtered.shape
if n_samples >= 4:
species2samples[clade] = n_samples
info('Done.')
return species2samples

Expand All @@ -378,8 +390,7 @@ def print_clades(self):
"""Prints the clades detected in the reconstructed markers"""
species2samples = self.detect_clades()
info('Detected clades: ')
sorted_species2samples = collections.OrderedDict(sorted(species2samples.items(), key=lambda kv: kv[1],
reverse=True))
sorted_species2samples = OrderedDict(sorted(species2samples.items(), key=lambda kv: kv[1], reverse=True))
with open(os.path.join(self.output_dir, 'print_clades_only.tsv'), 'w') as wf:
wf.write('Clade\tNumber_of_samples\n')
for species in sorted_species2samples:
Expand Down Expand Up @@ -432,25 +443,23 @@ def interactive_clade_selection(self):
self.clade = selected_clade


def filter_markers_samples(self, print_clades=False):
def filter_markers_samples(self):
"""Retrieves the filtered markers matrix with the filtered samples and references
Args:
print_clades (bool, optional): Whether it was run in the print_clade_only mode. Defaults to False.
"""
if not print_clades:
info("Getting markers from samples...")
info("Getting markers from samples...")
markers_matrix = self.get_markers_matrix_from_samples()
if not print_clades:
info("Done.")
if len(self.references) > 0:
info("Getting markers from references...")
markers_matrix += self.get_markers_from_references()
info("Done.")
info("Removing bad markers / samples...")
self.filter_markers_matrix(markers_matrix, messages=not print_clades)
if not print_clades:
info("Done.")
if len(self.references) > 0:
info("Getting markers from references...")
markers_matrix += self.get_markers_from_references()
info("Done.")
info("Removing markers / samples...")
markers_matrix = pd.DataFrame.from_records(markers_matrix, index='sample') # df with index samples and columns markers
markers_matrix_filtered = self.filter_markers_matrix(markers_matrix, messages=True)
info("Done.")

return markers_matrix_filtered


def run_strainphlan(self):
Expand All @@ -466,20 +475,19 @@ def run_strainphlan(self):
self.tmp_dir = tempfile.mkdtemp(dir=self.tmp_dir)
info("Done.")
info("Filtering markers and samples...")
self.filter_markers_samples()
markers_matrix = self.filter_markers_samples()
info("Done.")
info("Writing samples as markers' FASTA files...")
samples_as_markers_dir = self.matrix_markers_to_fasta()
samples_as_markers_dir = self.matrix_markers_to_fasta(markers_matrix)
info("Done.")
info("Calculating polymorphic rates...")
self.calculate_polymorphic_rates()
info("Done.")
info("Computing phylogeny...")
self.phylophlan_controller.compute_phylogeny(samples_as_markers_dir, len(self.cleaned_markers_matrix),
self.tmp_dir)
self.phylophlan_controller.compute_phylogeny(samples_as_markers_dir, len(markers_matrix), self.tmp_dir)
info("Done.")
info("Writing information file...")
self.write_info()
self.write_info(markers_matrix)
info("Done.")
if not self.debug:
info("Removing temporary files...")
Expand All @@ -489,7 +497,6 @@ def run_strainphlan(self):

def __init__(self, args):
self.clade_markers_names = None
self.cleaned_markers_matrix = None
self.database_controller = MetaphlanDatabaseController(args.database)
self.clade_markers_file = args.clade_markers
self.samples = args.samples
Expand Down
10 changes: 9 additions & 1 deletion metaphlan/utils/database_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import os
import pickle
import bz2

import pandas as pd
from Bio import SeqIO
try:
from .external_exec import generate_markers_fasta
Expand Down Expand Up @@ -37,7 +39,7 @@ def get_database_name(self):
return self.database.split('/')[-1][:-4]


def get_markers2species(self):
def get_markers2clade(self):
"""Retrieve information from the MetaPhlAn database
Returns:
Expand All @@ -47,6 +49,12 @@ def get_markers2species(self):
return {marker_name: marker_info['clade'] for marker_name, marker_info in self.database_pkl['markers'].items()}


def get_clade2markers(self):
    """Groups the database markers by the clade they belong to

    Returns:
        dict: mapping of each clade to a pandas Index with the names of its markers
    """
    # series with the marker names as index and their clades as values
    marker_clades = pd.Series(self.get_markers2clade())
    # grouping the series by its own values collects, per clade, the marker names labelling it
    return marker_clades.groupby(marker_clades).groups


def get_markers_for_clade(self, clade):
"""
Expand Down
2 changes: 1 addition & 1 deletion metaphlan/utils/sample2markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def filter_consensuses(self, consensuses, coverages):
list[ConsensusMarker]:
"""
markers2ext = self.database_controller.get_markers2ext()
markers2clade = self.database_controller.get_markers2species()
markers2clade = self.database_controller.get_markers2clade()
clade2nmarkers = Counter(markers2clade.values())

markers_all = list(consensuses.keys())
Expand Down

0 comments on commit 05c8d52

Please sign in to comment.