Skip to content

Commit

Permalink
Memory efficiency for print-clades-only and code improvements mostly …
Browse files Browse the repository at this point in the history
…for readability
  • Loading branch information
PuncocharM committed Oct 11, 2023
1 parent 05c8d52 commit f02065f
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 83 deletions.
63 changes: 32 additions & 31 deletions metaphlan/strainphlan.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,8 @@

try:
from .utils import *
from .utils.external_exec import run_command
except ImportError:
from utils import *
from utils.external_exec import run_command


class Strainphlan:
Expand All @@ -50,9 +48,8 @@ def get_markers_matrix_from_samples(self):
list: the list containing the samples-to-markers information
"""

markers_matrix = execute_pool(((Strainphlan.get_matrix_for_sample, sample, self.clade_markers_names,
self.breadth_thres) for sample in self.samples), self.nprocs)
return markers_matrix
return execute_pool(((Strainphlan.get_matrix_for_sample, sample, self.clade_markers_names,
self.breadth_thres) for sample in self.samples), self.nprocs)


@staticmethod
Expand Down Expand Up @@ -216,7 +213,8 @@ def process_reference(cls, reference_file, tmp_dir, clade_markers_file, clade_ma

consensus_markers = ConsensusMarkers([ConsensusMarker(m, s) for m, s in ext_markers.items()])
reference_name = cls.sample_path_to_name(reference_file)
consensus_markers.to_fasta(os.path.join(reference_markers_dir, f'{reference_name}.fna.bz2'), trim_ends=trim_sequences)
consensus_markers.to_fasta(os.path.join(reference_markers_dir, f'{reference_name}.fna.bz2'),
trim_ends=trim_sequences)

markers_matrix = {'sample': reference_file}
markers_matrix.update({m: int(m in ext_markers) for m in clade_markers})
Expand All @@ -236,7 +234,7 @@ def extract_markers_from_genome(reference_file, clade_markers_file):
"""
# load the raw fasta data
with util_fun.openrt(reference_file) as f:
with openrt(reference_file) as f:
input_file_data = f.read()

# parse the fasta
Expand Down Expand Up @@ -329,27 +327,28 @@ def write_info(self, markers_matrix):
filtered_names = [self.sample_path_to_name(sample) for sample in markers_matrix.index]
with open(os.path.join(self.output_dir, "{}.info".format(self.clade)), 'w') as info_file:
info_file.write("Clade: {}\n".format(self.clade))
info_file.write(
"Number of samples: {}\n".format(len(self.samples)))
info_file.write(
"Number of references: {}\n".format(len(self.references)))
info_file.write("Number of samples: {}\n".format(len(self.samples)))
info_file.write("Number of references: {}\n".format(len(self.references)))
info_file.write("Number of available markers for the clade: {}\n".format(len(self.clade_markers_names)))
info_file.write("Filtering parameters:\n")
info_file.write("\tNumber of bases to remove when trimming markers: {}\n".format(
self.trim_sequences))
info_file.write(f"\tMinimum number of markers to make a sample primary: {self.sample_with_n_markers}\n")
info_file.write(f"\tMinimum percentage of markers to make a sample primary: {self.sample_with_n_markers_perc}\n")
info_file.write(f"\tMinimum number of markers to keep a sample after filtering: {self.sample_with_n_markers_after_filt}\n")
info_file.write(f"\tMinimum percentage of markers to keep a sample after filtering: {self.sample_with_n_markers_after_filt_perc}\n")
info_file.write("\tNumber of bases to remove when trimming markers: {}\n".format(self.trim_sequences))
info_file.write(f"\tMinimum number of markers to make a sample primary: "
f"{self.sample_with_n_markers}\n")
info_file.write(f"\tMinimum percentage of markers to make a sample primary: "
f"{self.sample_with_n_markers_perc}\n")
info_file.write(f"\tMinimum number of markers to keep a sample after filtering: "
f"{self.sample_with_n_markers_after_filt}\n")
info_file.write(f"\tMinimum percentage of markers to keep a sample after filtering: "
f"{self.sample_with_n_markers_after_filt_perc}\n")
info_file.write(f"\tMinimum percentage of samples to keep a marker: {self.marker_in_n_samples_perc}\n")
info_file.write("Number of markers selected after filtering: {}\n".format(len(markers_matrix.columns)))
info_file.write("Number of samples after filtering: {}\n".format(len(
[sample for sample in self.samples if sample in markers_matrix.index])))
info_file.write("Number of references after filtering: {}\n".format(len([reference for reference in self.references if self.sample_path_to_name(reference) in filtered_names])))
info_file.write(
"PhyloPhlan phylogenetic precision mode: {}\n".format(self.phylophlan_mode))
info_file.write(
"Number of processes used: {}\n".format(self.nprocs))
n_samples = len([sample for sample in self.samples if sample in markers_matrix.index])
info_file.write("Number of samples after filtering: {}\n".format(n_samples))
n_refs = len([reference for reference in self.references
if self.sample_path_to_name(reference) in filtered_names])
info_file.write("Number of references after filtering: {}\n".format(n_refs))
info_file.write("PhyloPhlan phylogenetic precision mode: {}\n".format(self.phylophlan_mode))
info_file.write("Number of processes used: {}\n".format(self.nprocs))


def detect_clades(self):
Expand All @@ -363,15 +362,17 @@ def detect_clades(self):
sample2markers = {}
clades_to_check = set()
info('Processing samples...')
consensus_markers = execute_pool([(ConsensusMarkers.from_file, sample_path) for sample_path in self.samples], nprocs=self.nprocs)
consensus_markers = execute_pool(((ConsensusMarkers.from_file, sample_path) for sample_path in self.samples),
nprocs=self.nprocs, return_generator=True)
for sample_path, cm in zip(self.samples, consensus_markers):
markers = [marker.name for marker in cm.consensus_markers
if (marker.name in markers2clade and marker.breadth >= self.breadth_thres)]
sample2markers[sample_path] = markers
clades_to_check.update((markers2clade[m] for m in markers))

info('Constructing the big marker matrix')
markers_matrix_big = [pd.Series({m: 1 for m in markers}, name=sample) for sample, markers in sample2markers.items()]
markers_matrix_big = [pd.Series({m: 1 for m in markers}, name=sample)
for sample, markers in sample2markers.items()]
markers_matrix_big = pd.concat(markers_matrix_big, axis=1).fillna(0)

info(f'Checking {len(clades_to_check)} species')
Expand Down Expand Up @@ -405,8 +406,7 @@ def interactive_clade_selection(self):
info("The clade has been specified at the species level, starting interactive clade selection...")
species2sgbs = self.database_controller.get_species2sgbs()
if self.clade not in species2sgbs:
error('The specified species "{}" is not present in the database. Exiting...'.format(
self.clade), exit=True)
error('The specified species "{}" is not present in the database. Exiting...'.format(self.clade), exit=True)
sgbs_in_species = dict(sorted(
species2sgbs[self.clade].items(), key=lambda item: item[1], reverse=True))
if self.non_interactive or len(sgbs_in_species) == 1:
Expand Down Expand Up @@ -455,7 +455,8 @@ def filter_markers_samples(self):
markers_matrix += self.get_markers_from_references()
info("Done.")
info("Removing markers / samples...")
markers_matrix = pd.DataFrame.from_records(markers_matrix, index='sample') # df with index samples and columns markers
# df with index samples and columns markers
markers_matrix = pd.DataFrame.from_records(markers_matrix, index='sample')
markers_matrix_filtered = self.filter_markers_matrix(markers_matrix, messages=True)
info("Done.")

Expand Down Expand Up @@ -626,8 +627,8 @@ def main():
strainphlan_runner = Strainphlan(args)
strainphlan_runner.run_strainphlan()
exec_time = time.time() - t0
info("Finish StrainPhlAn {} execution ({} seconds): Results are stored at \"{}\"".format(
__version__, round(exec_time, 2), args.output_dir))
info("Finish StrainPhlAn {} execution ({} seconds): Results are stored at "
"\"{}\"".format(__version__, round(exec_time, 2), args.output_dir))


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions metaphlan/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .util_fun import info, warning, error, create_folder
from .util_fun import info, warning, error, create_folder, openrt
from .parallelisation import execute_pool
from .external_exec import decompress_bz2
from .external_exec import decompress_bz2, run_command
from .database_controller import MetaphlanDatabaseController
from .phylophlan_controller import Phylophlan3Controller
from .consensus_markers import ConsensusMarker, ConsensusMarkers
5 changes: 1 addition & 4 deletions metaphlan/utils/database_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,6 @@ def resolve_database(self, database):
def resolve_index(self):
"""Resolves the name to the MPA database
Args:
database (str): the name or path of the database
Returns:
str: the resolved name to the database
"""
Expand All @@ -180,7 +177,7 @@ def get_sgbs_size(self):

def __init__(self, database, bowtie2db=None):
self.mpa_script_folder = os.path.dirname(os.path.abspath(__file__))
if bowtie2db == None:
if bowtie2db is None:
self.default_db_folder = os.path.join(
self.mpa_script_folder, "..", "metaphlan_databases")
self.default_db_folder = os.environ.get(
Expand Down
6 changes: 4 additions & 2 deletions metaphlan/utils/external_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def compose_command(params, check=False, input_file=None, database=None, output_
environment.update(new_environment)

# find string sourrunded with " and make them as one string
# TODO: we should use shlex.split for this or drop this ugly function altogether (Michal)
quotes = [j for j, e in enumerate(command_line) if e == '"']

for s, e in zip(quotes[0::2], quotes[1::2]):
Expand Down Expand Up @@ -255,13 +256,14 @@ def run_command(cmd, shell=False, **kwargs):
else:
cmd_s = cmd

r = sb.run(cmd_s, capture_output=True, **kwargs)
r = sb.run(cmd_s, shell=shell, capture_output=True, **kwargs)

if r.returncode != 0:
stdout = r.stdout
stderr = r.stderr
if 'text' not in kwargs or not kwargs['text']:
if isinstance(stdout, bytes):
stdout = stdout.decode()
if isinstance(stderr, bytes):
stderr = stderr.decode()
error('Execution failed for command', cmd)
print('stdout: ')
Expand Down
45 changes: 33 additions & 12 deletions metaphlan/utils/parallelisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
__date__ = '23 Aug 2023'


from typing import Iterable, Callable, Any
from typing import Iterable
import itertools as it

try:
from .util_fun import error
Expand Down Expand Up @@ -51,25 +52,45 @@ def parallel_execution(arguments):
terminating.set()


def execute_pool(args, nprocs):
def iterator_shorter_than(i, ln):
try:
for _ in range(ln):
next(i)
except StopIteration:
return True
return False


def execute_pool_iter(args, nprocs):
try:
terminating = Event()
with Pool(initializer=init_terminating, initargs=(terminating,), processes=nprocs) as pool:
for r in pool.imap_unordered(parallel_execution, args, chunksize=CHUNKSIZE):
yield r
except Exception as e:
error('Parallel execution fails: {}'.format(e), exit=False)
raise e


def execute_pool(args, nprocs, return_generator=False):
"""
Creates a pool for a parallelized function and returns the results of each execution as a list
Args:
args (Iterable[tuple]): tuple with the function and its arguments
nprocs (int): number of procs to use
return_generator (bool): Whether to return a non-blocking generator instead of list
Returns:
list: the list with the results of the parallel executions
"""
args = list(args)
if nprocs == 1 or len(args) <= 1: # no need to initialize pool
return [function(*a) for function, *a in args]
args, args_tmp = it.tee(args) # duplicate the iterator not to consume it
if nprocs == 1 or iterator_shorter_than(args_tmp, 2): # no need to initialize pool
gen = (function(*a) for function, *a in args)
else:
terminating = Event()
with Pool(initializer=init_terminating, initargs=(terminating,), processes=nprocs) as pool:
try:
return [_ for _ in pool.imap_unordered(parallel_execution, args, chunksize=CHUNKSIZE)]
except Exception as e:
error('Parallel execution fails: {}'.format(e), exit=False)
raise e
gen = execute_pool_iter(args, nprocs)

if return_generator:
return gen
else:
return list(gen)
3 changes: 1 addition & 2 deletions metaphlan/utils/phylophlan_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def compute_phylogeny(self, samples_markers_dir, num_samples, tmp_dir):
"""
if not self.phylophlan_configuration:
info("\tGenerating PhyloPhlAn configuration file...")
self.phylophlan_configuration = generate_phylophlan_config_file(
tmp_dir, self.get_phylophlan_configuration())
self.phylophlan_configuration = generate_phylophlan_config_file(tmp_dir, self.get_phylophlan_configuration())
info("\tDone.")
info("\tExecuting PhyloPhlAn...")
self.execute_phylophlan(samples_markers_dir, tmp_dir)
Expand Down
Loading

0 comments on commit f02065f

Please sign in to comment.