diff --git a/metaphlan/strainphlan.py b/metaphlan/strainphlan.py index 22da2ca..0f91cc0 100755 --- a/metaphlan/strainphlan.py +++ b/metaphlan/strainphlan.py @@ -48,12 +48,12 @@ def get_markers_matrix_from_samples(self): list: the list containing the samples-to-markers information """ - return execute_pool(((Strainphlan.get_matrix_for_sample, sample, self.clade_markers_names, - self.breadth_thres) for sample in self.samples), self.nprocs) + return execute_pool(((Strainphlan.get_matrix_for_sample, sample_path, self.clade_markers_names, + self.breadth_thres) for sample_path in self.samples), self.nprocs) - @staticmethod - def get_matrix_for_sample(sample_path, clade_markers, breadth_thres): + @classmethod + def get_matrix_for_sample(cls, sample_path, clade_markers, breadth_thres): """Returns the matrix with the presence / absence of the clade markers in a samples Args: @@ -65,8 +65,9 @@ def get_matrix_for_sample(sample_path, clade_markers, breadth_thres): dict: dictionary containing the sample-to-markers information as a binary matrix """ sample = ConsensusMarkers.from_file(sample_path) + sample_name = cls.sample_path_to_name(sample_path) - markers = {"sample": sample_path} + markers = {"sample_name": sample_name} markers.update({m: 0 for m in clade_markers}) markers.update({marker.name: 1 for marker in sample.consensus_markers if marker.name in clade_markers and marker.breadth >= breadth_thres}) @@ -113,17 +114,17 @@ def filter_markers_matrix(self, markers_matrix, messages=True): return markers_matrix - def copy_filtered_references(self, markers_tmp_dir, filtered_samples): + def copy_filtered_references(self, markers_tmp_dir, filtered_sample_names): """Copies the FASTA files of the filtered references to be processed by PhyloPhlAn Args: markers_tmp_dir (str): the temporary folder where to copy the reference genomes - filtered_samples (set): the set of samples after filtering + filtered_sample_names (set): the set of samples after filtering """ for reference in self.references: reference_name = self.sample_path_to_name(reference) - if reference_name in filtered_samples: + if reference_name in filtered_sample_names: reference_marker = os.path.join(self.tmp_dir, "reference_markers", f'{reference_name}.fna.bz2') copyfile(reference_marker, os.path.join(markers_tmp_dir, f"{reference_name}.fna.bz2")) @@ -139,31 +140,31 @@ def matrix_markers_to_fasta(self, markers_matrix): """ markers_tmp_dir = os.path.join(self.tmp_dir, "{}.StrainPhlAn4".format(self.clade)) create_folder(markers_tmp_dir) - filtered_samples = set(markers_matrix.index) + filtered_sample_names = set(markers_matrix.index) filtered_markers = set(markers_matrix.columns) - self.copy_filtered_references(markers_tmp_dir, filtered_samples) - execute_pool(((Strainphlan.sample_markers_to_fasta, sample, filtered_samples, filtered_markers, - self.trim_sequences, markers_tmp_dir) for sample in self.samples), self.nprocs) + self.copy_filtered_references(markers_tmp_dir, filtered_sample_names) + filtered_sample_paths = [sample_path for sample_path in self.samples + if Strainphlan.sample_path_to_name(sample_path) in filtered_sample_names] + execute_pool(((Strainphlan.sample_markers_to_fasta, sample_path, filtered_markers, self.trim_sequences, + markers_tmp_dir) for sample_path in filtered_sample_paths), self.nprocs) return markers_tmp_dir @classmethod - def sample_markers_to_fasta(cls, sample_path, filtered_samples, filtered_markers, trim_sequences, markers_tmp_dir): + def sample_markers_to_fasta(cls, sample_path, filtered_markers, trim_sequences, markers_tmp_dir): """Writes a FASTA file with the filtered clade markers of a sample Args: sample_path (str): the path to the sample filtered_markers: - filtered_samples: trim_sequences: markers_tmp_dir (str): the temporary folder were the FASTA file is written """ - if sample_path in filtered_samples: - sample_name = cls.sample_path_to_name(sample_path) - marker_output_file = os.path.join(markers_tmp_dir, f'{sample_name}.fna.bz2') - sample = ConsensusMarkers.from_file(sample_path) - sample.consensus_markers = [m for m in sample.consensus_markers if m.name in filtered_markers] - sample.to_fasta(marker_output_file, trim_ends=trim_sequences) + sample_name = cls.sample_path_to_name(sample_path) + marker_output_file = os.path.join(markers_tmp_dir, f'{sample_name}.fna.bz2') + sample = ConsensusMarkers.from_file(sample_path) + sample.consensus_markers = [m for m in sample.consensus_markers if m.name in filtered_markers] + sample.to_fasta(marker_output_file, trim_ends=trim_sequences) def get_markers_from_references(self): @@ -188,11 +189,11 @@ def get_markers_from_references(self): @classmethod - def process_reference(cls, reference_file, tmp_dir, clade_markers_file, clade_markers, trim_sequences): + def process_reference(cls, reference_path, tmp_dir, clade_markers_file, clade_markers, trim_sequences): """Processes each reference file and get a markers dictionary to add to the markers matrix Args: - reference_file (str): path to the reference file + reference_path (str): path to the reference file tmp_dir (str): the temporary folder where the BLASTn results where saved clade_markers_file (str): clade_markers (Iterable): the list with the clade markers names @@ -201,29 +202,78 @@ def process_reference(cls, reference_file, tmp_dir, clade_markers_file, clade_ma Returns: dict: the dictionary with the reference-to-markers information """ - if reference_file.endswith(".bz2"): + if reference_path.endswith(".bz2"): uncompressed_refernces_dir = os.path.join(tmp_dir, "uncompressed_references") os.makedirs(uncompressed_refernces_dir, exist_ok=True) - reference_file = decompress_bz2(reference_file, uncompressed_refernces_dir) + reference_path = decompress_bz2(reference_path, uncompressed_refernces_dir) - ext_markers = cls.extract_markers_from_genome(reference_file, clade_markers_file) + ext_markers = cls.extract_markers_from_genome(reference_path, clade_markers_file) reference_markers_dir = os.path.join(tmp_dir, "reference_markers") os.makedirs(reference_markers_dir, exist_ok=True) consensus_markers = ConsensusMarkers([ConsensusMarker(m, s) for m, s in ext_markers.items()]) - reference_name = cls.sample_path_to_name(reference_file) + reference_name = cls.sample_path_to_name(reference_path) consensus_markers.to_fasta(os.path.join(reference_markers_dir, f'{reference_name}.fna.bz2'), trim_ends=trim_sequences) - markers_matrix = {'sample': reference_file} + markers_matrix = {'sample_name': reference_name} markers_matrix.update({m: int(m in ext_markers) for m in clade_markers}) return markers_matrix @staticmethod - def extract_markers_from_genome(reference_file, clade_markers_file): + def extract_with_btop(sseq, btop, qstart, qend, qlen, sstart, send): + btop = re.split(r'(\d+|\D{2})', btop)[1::2] # blast trace-back operations + + assert qend >= qstart # query should be forward + strand = 1 if send >= sstart else -1 # whether reverse-complemented + + qi = qstart - 1 + si = sstart - 1 + ext_s = '-' * qi + for op in btop: + if op.isnumeric(): # match + op = int(op) + b = si + strand * op + if b == -1 and strand == -1: # we should go to the beginning, but -1 gets interpreted as the end + b = 0 + ext_s += sseq[si: b: strand] + sseq[0] + else: + ext_s += sseq[si: b: strand] + + qi += op + si += strand * op + else: + if strand == -1: + op = str(Seq.Seq(op).complement()) + + if op[0] == '-': # query gap + si += strand + elif op[1] == '-': # subject gap + ext_s += '-' + qi += 1 + else: + qi += 1 + si += strand + ext_s += op[1] + + ext_s += '-' * (qlen - qend) + + # Check we parsed everything correctly + assert qi == qend + assert si + 1 - strand == send + assert len(ext_s) == qlen + + if strand == -1: + ext_s = str(Seq.Seq(ext_s).complement()) + + return ext_s + + + @classmethod + def extract_markers_from_genome(cls, reference_file, clade_markers_file): """ Args: @@ -248,48 +298,39 @@ def extract_markers_from_genome(reference_file, clade_markers_file): r = run_command(cmd, input=input_file_data, text=True) # load the blast output - df = pd.read_csv(io.StringIO(r.stdout), sep='\t', names=columns.split(' ')) + df = pd.read_csv(io.StringIO(r.stdout), sep='\t', names=columns.split(' '), dtype={'btop': str}) - ext_markers = {} - for _, row in df.iterrows(): + df['ext_s'] = '' + for idx, row in df.iterrows(): sseq = str(input_seqs[row['sseqid']].seq) - btop = re.split(r'(\d+|\D{2})', row['btop'])[1::2] # blast trace-back operations - - assert row['qend'] >= row['qstart'] - strand = 1 if row['send'] >= row['sstart'] else -1 # whether reverse-complemented - - qi = row['qstart'] - 1 - si = row['sstart'] - 1 - ext_s = '-' * qi - for op in btop: - if op.isnumeric(): - op = int(op) - ext_s += sseq[si: si + strand * op: strand] - qi += op - si += strand * op - else: - if strand == -1: - op = str(Seq.Seq(op).complement()) + ext_s = cls.extract_with_btop(sseq, **row[['btop', 'qstart', 'qend', 'qlen', 'sstart', 'send']]) + df.loc[idx, 'ext_s'] = ext_s - if op[0] == '-': # query gap - si += strand - elif op[1] == '-': # subject gap - ext_s += '-' - qi += 1 - else: - qi += 1 - si += strand - ext_s += op[1] + def segments_overlap(sa, sb): + """Intervals are open at the right, e.g. s = [s0, s1)""" + if sa[0] >= sb[1] or sb[0] >= sa[1]: + return False + return True - ext_s += '-' * (row['qlen'] - row['qend']) - - assert qi == row['qend'] - assert len(ext_s) == row['qlen'] - - if strand == -1: - ext_s = str(Seq.Seq(ext_s).complement()) - - ext_markers[row['qseqid']] = ext_s + ext_markers = {} + for query, idx in df.groupby('qseqid').groups.items(): # for each query/marker + df_query = df.loc[idx] + best_ref = df_query.iloc[0]['sseqid'] # take the best hit and consider only that contig/reference + df_query = df_query.query(f'sseqid=="{best_ref}"') + covered_regions = [] + ext_s = ['-'] * df_query.iloc[0]['qlen'] + for _, row in df_query.iterrows(): + reg = (row['qstart'], row['qend']) + if any(segments_overlap(reg, cr) for cr in covered_regions): + continue + + # non-overlaping hit => expand + covered_regions.append(reg) + assert all('-' in [a, b] for a, b in zip(ext_s, row['ext_s'])) # make sure they really don't overlap + ext_s = [a if b == '-' else b for a, b in zip(ext_s, row['ext_s'])] + ext_s = ''.join(ext_s) + + ext_markers[query] = ext_s return ext_markers @@ -324,7 +365,6 @@ def calculate_polymorphic_rates(self): def write_info(self, markers_matrix): """Writes the information file for the execution""" - filtered_names = [self.sample_path_to_name(sample) for sample in markers_matrix.index] with open(os.path.join(self.output_dir, "{}.info".format(self.clade)), 'w') as info_file: info_file.write("Clade: {}\n".format(self.clade)) info_file.write("Number of samples: {}\n".format(len(self.samples))) @@ -342,10 +382,11 @@ def write_info(self, markers_matrix): f"{self.sample_with_n_markers_after_filt_perc}\n") info_file.write(f"\tMinimum percentage of samples to keep a marker: {self.marker_in_n_samples_perc}\n") info_file.write("Number of markers selected after filtering: {}\n".format(len(markers_matrix.columns))) - n_samples = len([sample for sample in self.samples if sample in markers_matrix.index]) + n_samples = len([sample for sample in self.samples + if Strainphlan.sample_path_to_name(sample) in markers_matrix.index]) info_file.write("Number of samples after filtering: {}\n".format(n_samples)) n_refs = len([reference for reference in self.references - if self.sample_path_to_name(reference) in filtered_names]) + if Strainphlan.sample_path_to_name(reference) in markers_matrix.index]) info_file.write("Number of references after filtering: {}\n".format(n_refs)) info_file.write("PhyloPhlan phylogenetic precision mode: {}\n".format(self.phylophlan_mode)) info_file.write("Number of processes used: {}\n".format(self.nprocs)) @@ -456,7 +497,7 @@ def filter_markers_samples(self): info("Done.") info("Removing markers / samples...") # df with index samples and columns markers - markers_matrix = pd.DataFrame.from_records(markers_matrix, index='sample') + markers_matrix = pd.DataFrame.from_records(markers_matrix, index='sample_name') markers_matrix_filtered = self.filter_markers_matrix(markers_matrix, messages=True) info("Done.") @@ -501,7 +542,7 @@ def __init__(self, args): self.database_controller = MetaphlanDatabaseController(args.database) self.clade_markers_file = args.clade_markers self.samples = args.samples - self.references = args.references if not args.print_clades_only else [] + self.references = args.references self.clade = args.clade self.output_dir = args.output_dir self.trim_sequences = args.trim_sequences @@ -615,6 +656,11 @@ def check_params(args): if not os.path.exists(r): error('The reference file \"{}\" does not exist'.format(r), exit=True) + sample_names = [Strainphlan.sample_path_to_name(s) for s in args.samples] + ref_names = [Strainphlan.sample_path_to_name(s) for s in args.references] + if len(sample_names) + len(ref_names) != len(set(sample_names + ref_names)): + error('Some sample or reference names are duplicated') + def main(): t0 = time.time()