Skip to content

Commit

Permalink
Blasting in references: merge non-overlapping hits and other fixes an…
Browse files Browse the repository at this point in the history
…d code improvements
  • Loading branch information
PuncocharM committed Oct 11, 2023
1 parent f02065f commit bc7d664
Showing 1 changed file with 117 additions and 71 deletions.
188 changes: 117 additions & 71 deletions metaphlan/strainphlan.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ def get_markers_matrix_from_samples(self):
list: the list containing the samples-to-markers information
"""

return execute_pool(((Strainphlan.get_matrix_for_sample, sample, self.clade_markers_names,
self.breadth_thres) for sample in self.samples), self.nprocs)
return execute_pool(((Strainphlan.get_matrix_for_sample, sample_path, self.clade_markers_names,
self.breadth_thres) for sample_path in self.samples), self.nprocs)


@staticmethod
def get_matrix_for_sample(sample_path, clade_markers, breadth_thres):
@classmethod
def get_matrix_for_sample(cls, sample_path, clade_markers, breadth_thres):
"""Returns the matrix with the presence / absence of the clade markers in a samples
Args:
Expand All @@ -65,8 +65,9 @@ def get_matrix_for_sample(sample_path, clade_markers, breadth_thres):
dict: dictionary containing the sample-to-markers information as a binary matrix
"""
sample = ConsensusMarkers.from_file(sample_path)
sample_name = cls.sample_path_to_name(sample_path)

markers = {"sample": sample_path}
markers = {"sample_name": sample_name}
markers.update({m: 0 for m in clade_markers})
markers.update({marker.name: 1 for marker in sample.consensus_markers
if marker.name in clade_markers and marker.breadth >= breadth_thres})
Expand Down Expand Up @@ -113,17 +114,17 @@ def filter_markers_matrix(self, markers_matrix, messages=True):
return markers_matrix


def copy_filtered_references(self, markers_tmp_dir, filtered_samples):
def copy_filtered_references(self, markers_tmp_dir, filtered_sample_names):
"""Copies the FASTA files of the filtered references to be processed by PhyloPhlAn
Args:
markers_tmp_dir (str): the temporary folder where to copy the reference genomes
filtered_samples (set): the set of samples after filtering
filtered_sample_names (set): the set of samples after filtering
"""
for reference in self.references:
reference_name = self.sample_path_to_name(reference)

if reference_name in filtered_samples:
if reference_name in filtered_sample_names:
reference_marker = os.path.join(self.tmp_dir, "reference_markers", f'{reference_name}.fna.bz2')
copyfile(reference_marker, os.path.join(markers_tmp_dir, f"{reference_name}.fna.bz2"))

Expand All @@ -139,31 +140,31 @@ def matrix_markers_to_fasta(self, markers_matrix):
"""
markers_tmp_dir = os.path.join(self.tmp_dir, "{}.StrainPhlAn4".format(self.clade))
create_folder(markers_tmp_dir)
filtered_samples = set(markers_matrix.index)
filtered_sample_names = set(markers_matrix.index)
filtered_markers = set(markers_matrix.columns)
self.copy_filtered_references(markers_tmp_dir, filtered_samples)
execute_pool(((Strainphlan.sample_markers_to_fasta, sample, filtered_samples, filtered_markers,
self.trim_sequences, markers_tmp_dir) for sample in self.samples), self.nprocs)
self.copy_filtered_references(markers_tmp_dir, filtered_sample_names)
filtered_sample_paths = [sample_path for sample_path in self.samples
if Strainphlan.sample_path_to_name(sample_path) in filtered_sample_names]
execute_pool(((Strainphlan.sample_markers_to_fasta, sample_path, filtered_markers, self.trim_sequences,
markers_tmp_dir) for sample_path in filtered_sample_paths), self.nprocs)
return markers_tmp_dir


@classmethod
def sample_markers_to_fasta(cls, sample_path, filtered_samples, filtered_markers, trim_sequences, markers_tmp_dir):
def sample_markers_to_fasta(cls, sample_path, filtered_markers, trim_sequences, markers_tmp_dir):
"""Writes a FASTA file with the filtered clade markers of a sample
Args:
sample_path (str): the path to the sample
filtered_markers:
filtered_samples:
trim_sequences:
markers_tmp_dir (str): the temporary folder were the FASTA file is written
"""
if sample_path in filtered_samples:
sample_name = cls.sample_path_to_name(sample_path)
marker_output_file = os.path.join(markers_tmp_dir, f'{sample_name}.fna.bz2')
sample = ConsensusMarkers.from_file(sample_path)
sample.consensus_markers = [m for m in sample.consensus_markers if m.name in filtered_markers]
sample.to_fasta(marker_output_file, trim_ends=trim_sequences)
sample_name = cls.sample_path_to_name(sample_path)
marker_output_file = os.path.join(markers_tmp_dir, f'{sample_name}.fna.bz2')
sample = ConsensusMarkers.from_file(sample_path)
sample.consensus_markers = [m for m in sample.consensus_markers if m.name in filtered_markers]
sample.to_fasta(marker_output_file, trim_ends=trim_sequences)


def get_markers_from_references(self):
Expand All @@ -188,11 +189,11 @@ def get_markers_from_references(self):


@classmethod
def process_reference(cls, reference_file, tmp_dir, clade_markers_file, clade_markers, trim_sequences):
def process_reference(cls, reference_path, tmp_dir, clade_markers_file, clade_markers, trim_sequences):
"""Processes each reference file and get a markers dictionary to add to the markers matrix
Args:
reference_file (str): path to the reference file
reference_path (str): path to the reference file
tmp_dir (str): the temporary folder where the BLASTn results where saved
clade_markers_file (str):
clade_markers (Iterable): the list with the clade markers names
Expand All @@ -201,29 +202,78 @@ def process_reference(cls, reference_file, tmp_dir, clade_markers_file, clade_ma
Returns:
dict: the dictionary with the reference-to-markers information
"""
if reference_file.endswith(".bz2"):
if reference_path.endswith(".bz2"):
uncompressed_refernces_dir = os.path.join(tmp_dir, "uncompressed_references")
os.makedirs(uncompressed_refernces_dir, exist_ok=True)
reference_file = decompress_bz2(reference_file, uncompressed_refernces_dir)
reference_path = decompress_bz2(reference_path, uncompressed_refernces_dir)

ext_markers = cls.extract_markers_from_genome(reference_file, clade_markers_file)
ext_markers = cls.extract_markers_from_genome(reference_path, clade_markers_file)

reference_markers_dir = os.path.join(tmp_dir, "reference_markers")
os.makedirs(reference_markers_dir, exist_ok=True)

consensus_markers = ConsensusMarkers([ConsensusMarker(m, s) for m, s in ext_markers.items()])
reference_name = cls.sample_path_to_name(reference_file)
reference_name = cls.sample_path_to_name(reference_path)
consensus_markers.to_fasta(os.path.join(reference_markers_dir, f'{reference_name}.fna.bz2'),
trim_ends=trim_sequences)

markers_matrix = {'sample': reference_file}
markers_matrix = {'sample_name': reference_name}
markers_matrix.update({m: int(m in ext_markers) for m in clade_markers})

return markers_matrix


@staticmethod
def extract_markers_from_genome(reference_file, clade_markers_file):
def extract_with_btop(sseq, btop, qstart, qend, qlen, sstart, send):
btop = re.split(r'(\d+|\D{2})', btop)[1::2] # blast trace-back operations

assert qend >= qstart # query should be forward
strand = 1 if send >= sstart else -1 # whether reverse-complemented

qi = qstart - 1
si = sstart - 1
ext_s = '-' * qi
for op in btop:
if op.isnumeric(): # match
op = int(op)
b = si + strand * op
if b == -1 and strand == -1: # we should go to the beginning, but -1 gets interpreted as the end
b = 0
ext_s += sseq[si: b: strand] + sseq[0]
else:
ext_s += sseq[si: b: strand]

qi += op
si += strand * op
else:
if strand == -1:
op = str(Seq.Seq(op).complement())

if op[0] == '-': # query gap
si += strand
elif op[1] == '-': # subject gap
ext_s += '-'
qi += 1
else:
qi += 1
si += strand
ext_s += op[1]

ext_s += '-' * (qlen - qend)

# Check we parsed everything correctly
assert qi == qend
assert si + 1 - strand == send
assert len(ext_s) == qlen

if strand == -1:
ext_s = str(Seq.Seq(ext_s).complement())

return ext_s


@classmethod
def extract_markers_from_genome(cls, reference_file, clade_markers_file):
"""
Args:
Expand All @@ -248,48 +298,39 @@ def extract_markers_from_genome(reference_file, clade_markers_file):
r = run_command(cmd, input=input_file_data, text=True)

# load the blast output
df = pd.read_csv(io.StringIO(r.stdout), sep='\t', names=columns.split(' '))
df = pd.read_csv(io.StringIO(r.stdout), sep='\t', names=columns.split(' '), dtype={'btop': str})

ext_markers = {}
for _, row in df.iterrows():
df['ext_s'] = ''
for idx, row in df.iterrows():
sseq = str(input_seqs[row['sseqid']].seq)
btop = re.split(r'(\d+|\D{2})', row['btop'])[1::2] # blast trace-back operations

assert row['qend'] >= row['qstart']
strand = 1 if row['send'] >= row['sstart'] else -1 # whether reverse-complemented

qi = row['qstart'] - 1
si = row['sstart'] - 1
ext_s = '-' * qi
for op in btop:
if op.isnumeric():
op = int(op)
ext_s += sseq[si: si + strand * op: strand]
qi += op
si += strand * op
else:
if strand == -1:
op = str(Seq.Seq(op).complement())
ext_s = cls.extract_with_btop(sseq, **row[['btop', 'qstart', 'qend', 'qlen', 'sstart', 'send']])
df.loc[idx, 'ext_s'] = ext_s

if op[0] == '-': # query gap
si += strand
elif op[1] == '-': # subject gap
ext_s += '-'
qi += 1
else:
qi += 1
si += strand
ext_s += op[1]
def segments_overlap(sa, sb):
"""Intervals are open at the right, e.g. s = [s0, s1)"""
if sa[0] >= sb[1] or sb[0] >= sa[1]:
return False
return True

ext_s += '-' * (row['qlen'] - row['qend'])

assert qi == row['qend']
assert len(ext_s) == row['qlen']

if strand == -1:
ext_s = str(Seq.Seq(ext_s).complement())

ext_markers[row['qseqid']] = ext_s
ext_markers = {}
for query, idx in df.groupby('qseqid').groups.items(): # for each query/marker
df_query = df.loc[idx]
best_ref = df_query.iloc[0]['sseqid'] # take the best hit and consider only that contig/reference
df_query = df_query.query(f'sseqid=="{best_ref}"')
covered_regions = []
ext_s = ['-'] * df_query.iloc[0]['qlen']
for _, row in df_query.iterrows():
reg = (row['qstart'], row['qend'])
if any(segments_overlap(reg, cr) for cr in covered_regions):
continue

# non-overlaping hit => expand
covered_regions.append(reg)
assert all('-' in [a, b] for a, b in zip(ext_s, row['ext_s'])) # make sure they really don't overlap
ext_s = [a if b == '-' else b for a, b in zip(ext_s, row['ext_s'])]
ext_s = ''.join(ext_s)

ext_markers[query] = ext_s

return ext_markers

Expand Down Expand Up @@ -324,7 +365,6 @@ def calculate_polymorphic_rates(self):

def write_info(self, markers_matrix):
"""Writes the information file for the execution"""
filtered_names = [self.sample_path_to_name(sample) for sample in markers_matrix.index]
with open(os.path.join(self.output_dir, "{}.info".format(self.clade)), 'w') as info_file:
info_file.write("Clade: {}\n".format(self.clade))
info_file.write("Number of samples: {}\n".format(len(self.samples)))
Expand All @@ -342,10 +382,11 @@ def write_info(self, markers_matrix):
f"{self.sample_with_n_markers_after_filt_perc}\n")
info_file.write(f"\tMinimum percentage of samples to keep a marker: {self.marker_in_n_samples_perc}\n")
info_file.write("Number of markers selected after filtering: {}\n".format(len(markers_matrix.columns)))
n_samples = len([sample for sample in self.samples if sample in markers_matrix.index])
n_samples = len([sample for sample in self.samples
if Strainphlan.sample_path_to_name(sample) in markers_matrix.index])
info_file.write("Number of samples after filtering: {}\n".format(n_samples))
n_refs = len([reference for reference in self.references
if self.sample_path_to_name(reference) in filtered_names])
if Strainphlan.sample_path_to_name(reference) in markers_matrix.index])
info_file.write("Number of references after filtering: {}\n".format(n_refs))
info_file.write("PhyloPhlan phylogenetic precision mode: {}\n".format(self.phylophlan_mode))
info_file.write("Number of processes used: {}\n".format(self.nprocs))
Expand Down Expand Up @@ -456,7 +497,7 @@ def filter_markers_samples(self):
info("Done.")
info("Removing markers / samples...")
# df with index samples and columns markers
markers_matrix = pd.DataFrame.from_records(markers_matrix, index='sample')
markers_matrix = pd.DataFrame.from_records(markers_matrix, index='sample_name')
markers_matrix_filtered = self.filter_markers_matrix(markers_matrix, messages=True)
info("Done.")

Expand Down Expand Up @@ -501,7 +542,7 @@ def __init__(self, args):
self.database_controller = MetaphlanDatabaseController(args.database)
self.clade_markers_file = args.clade_markers
self.samples = args.samples
self.references = args.references if not args.print_clades_only else []
self.references = args.references
self.clade = args.clade
self.output_dir = args.output_dir
self.trim_sequences = args.trim_sequences
Expand Down Expand Up @@ -615,6 +656,11 @@ def check_params(args):
if not os.path.exists(r):
error('The reference file \"{}\" does not exist'.format(r), exit=True)

sample_names = [Strainphlan.sample_path_to_name(s) for s in args.samples]
ref_names = [Strainphlan.sample_path_to_name(s) for s in args.references]
if len(sample_names) + len(ref_names) != len(set(sample_names + ref_names)):
error('Some sample or reference names are duplicated')


def main():
t0 = time.time()
Expand Down

0 comments on commit bc7d664

Please sign in to comment.