Skip to content

Commit

Permalink
Fix print clades only bug
Browse files Browse the repository at this point in the history
  • Loading branch information
PuncocharM committed Apr 29, 2024
1 parent 4be0d0a commit 2fa9e42
Showing 1 changed file with 27 additions and 18 deletions.
45 changes: 27 additions & 18 deletions metaphlan/strainphlan.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,42 +440,51 @@ def detect_clades(self):
clade2markers = self.database_controller.get_clade2markers()
sample2markers = {}
clades_to_check = set()
all_markers = set()
info('Processing samples...')
consensus_markers = execute_pool(((ConsensusMarkers.from_file, sample_path) for sample_path in self.samples),
nprocs=self.nprocs, return_generator=True, ordered=True)
for sample_path, cm in zip(self.samples, consensus_markers):
markers = [marker.name for marker in cm.consensus_markers
if (marker.name in markers2clade and marker.breadth >= self.breadth_thres)]
sample2markers[sample_path] = markers
clades_to_check.update((markers2clade[m] for m in markers))
sample_markers = [marker.name for marker in cm.consensus_markers
if marker.name in markers2clade and marker.breadth >= self.breadth_thres]
sample2markers[sample_path] = sample_markers
all_markers.update(sample_markers)
clades_to_check.update((markers2clade[m] for m in sample_markers))

info('Constructing the big marker matrix')
markers_matrix_big = [pd.Series({m: 1 for m in markers}, name=sample)
for sample, markers in sample2markers.items()]
markers_matrix_big = pd.concat(markers_matrix_big, axis=1).fillna(0)
all_samples = list(sample2markers.keys())
all_markers = list(all_markers)
markers_matrix_big = np.zeros(shape=(len(all_samples), len(all_markers)), dtype=int)
for i, sample in enumerate(all_samples):
sample_markers = set(sample2markers[sample])
for j, marker in enumerate(all_markers):
if marker in sample_markers:
markers_matrix_big[i, j] = 1

markers_matrix_big = pd.DataFrame(markers_matrix_big, index=all_samples, columns=all_markers)

info(f'Checking {len(clades_to_check)} species')
species2samples = {}
for clade in clades_to_check:
markers_matrix = markers_matrix_big.reindex(clade2markers[clade]).fillna(0)
markers_for_clade = markers_matrix_big.columns.intersection(clade2markers[clade])
markers_matrix = markers_matrix_big.loc[:, markers_for_clade]
markers_matrix_filtered = self.filter_markers_matrix(markers_matrix, messages=False)
n_samples, _ = markers_matrix_filtered.shape
if n_samples >= 4:
species2samples[clade] = n_samples
if len(markers_matrix_filtered) >= 4:
species2samples[clade] = markers_matrix_filtered.index
info('Done.')
return species2samples


def print_clades(self):
"""Prints the clades detected in the reconstructed markers"""
species2samples = self.detect_clades()
info('Detected clades: ')
sorted_species2samples = OrderedDict(sorted(species2samples.items(), key=lambda kv: kv[1], reverse=True))
with open(os.path.join(self.output_dir, 'print_clades_only.tsv'), 'w') as wf:
wf.write('Clade\tNumber_of_samples\n')
for species in sorted_species2samples:
info('\t{}: in {} samples.'.format(species, sorted_species2samples[species]))
wf.write('{}\t{}\n'.format(species, sorted_species2samples[species]))
info(f'Detected {len(species2samples)} clades: ')
with open(os.path.join(self.output_dir, 'print_clades_only.tsv'), 'w') as f:
f.write('Clade\tNumber_of_samples\tSamples\n')
for species, samples in sorted(species2samples.items(), key=lambda kv: len(kv[1]), reverse=True):
info('\t{}: in {} samples.'.format(species, len(samples)))
f.write('{}\t{}\t{}\n'.format(species, len(samples),
','.join(map(Strainphlan.sample_path_to_name, samples))))
info('Done.')


Expand Down

0 comments on commit 2fa9e42

Please sign in to comment.