From 2fa9e42c1db08bdb45bf6c1177f2cfb46d406980 Mon Sep 17 00:00:00 2001
From: Michal Puncochar
Date: Mon, 29 Apr 2024 15:03:01 +0200
Subject: [PATCH] Fix print clades only bug

---
Reviewer notes (placed after the "---" separator, so they are not part of
the commit message):
  * Line breaks and indentation were restored from a copy of this patch whose
    newlines had been collapsed. Blank context lines and continuation-line
    alignment are reconstructed; verify them against the target file, or
    apply with whitespace tolerance (e.g. `git apply --ignore-whitespace`).
  * detect_clades() now stores markers_matrix_filtered.index for each clade
    instead of a sample count; print_clades() writes len(samples) and adds a
    third "Samples" column built with Strainphlan.sample_path_to_name.
  * The added lines use `np`, and this patch adds no import; presumably numpy
    is already imported as `np` in strainphlan.py -- TODO confirm.

 metaphlan/strainphlan.py | 45 ++++++++++++++++++++++++----------------
 1 file changed, 27 insertions(+), 18 deletions(-)

diff --git a/metaphlan/strainphlan.py b/metaphlan/strainphlan.py
index 249a685..75fd52d 100755
--- a/metaphlan/strainphlan.py
+++ b/metaphlan/strainphlan.py
@@ -440,28 +440,37 @@ def detect_clades(self):
         clade2markers = self.database_controller.get_clade2markers()
         sample2markers = {}
         clades_to_check = set()
+        all_markers = set()
         info('Processing samples...')
         consensus_markers = execute_pool(((ConsensusMarkers.from_file, sample_path) for sample_path in self.samples),
                                          nprocs=self.nprocs, return_generator=True, ordered=True)
         for sample_path, cm in zip(self.samples, consensus_markers):
-            markers = [marker.name for marker in cm.consensus_markers
-                       if (marker.name in markers2clade and marker.breadth >= self.breadth_thres)]
-            sample2markers[sample_path] = markers
-            clades_to_check.update((markers2clade[m] for m in markers))
+            sample_markers = [marker.name for marker in cm.consensus_markers
+                              if marker.name in markers2clade and marker.breadth >= self.breadth_thres]
+            sample2markers[sample_path] = sample_markers
+            all_markers.update(sample_markers)
+            clades_to_check.update((markers2clade[m] for m in sample_markers))
 
         info('Constructing the big marker matrix')
-        markers_matrix_big = [pd.Series({m: 1 for m in markers}, name=sample)
-                              for sample, markers in sample2markers.items()]
-        markers_matrix_big = pd.concat(markers_matrix_big, axis=1).fillna(0)
+        all_samples = list(sample2markers.keys())
+        all_markers = list(all_markers)
+        markers_matrix_big = np.zeros(shape=(len(all_samples), len(all_markers)), dtype=int)
+        for i, sample in enumerate(all_samples):
+            sample_markers = set(sample2markers[sample])
+            for j, marker in enumerate(all_markers):
+                if marker in sample_markers:
+                    markers_matrix_big[i, j] = 1
+
+        markers_matrix_big = pd.DataFrame(markers_matrix_big, index=all_samples, columns=all_markers)
 
         info(f'Checking {len(clades_to_check)} species')
         species2samples = {}
         for clade in clades_to_check:
-            markers_matrix = markers_matrix_big.reindex(clade2markers[clade]).fillna(0)
+            markers_for_clade = markers_matrix_big.columns.intersection(clade2markers[clade])
+            markers_matrix = markers_matrix_big.loc[:, markers_for_clade]
             markers_matrix_filtered = self.filter_markers_matrix(markers_matrix, messages=False)
-            n_samples, _ = markers_matrix_filtered.shape
-            if n_samples >= 4:
-                species2samples[clade] = n_samples
+            if len(markers_matrix_filtered) >= 4:
+                species2samples[clade] = markers_matrix_filtered.index
         info('Done.')
         return species2samples
 
@@ -469,13 +478,13 @@ def detect_clades(self):
     def print_clades(self):
         """Prints the clades detected in the reconstructed markers"""
         species2samples = self.detect_clades()
-        info('Detected clades: ')
-        sorted_species2samples = OrderedDict(sorted(species2samples.items(), key=lambda kv: kv[1], reverse=True))
-        with open(os.path.join(self.output_dir, 'print_clades_only.tsv'), 'w') as wf:
-            wf.write('Clade\tNumber_of_samples\n')
-            for species in sorted_species2samples:
-                info('\t{}: in {} samples.'.format(species, sorted_species2samples[species]))
-                wf.write('{}\t{}\n'.format(species, sorted_species2samples[species]))
+        info(f'Detected {len(species2samples)} clades: ')
+        with open(os.path.join(self.output_dir, 'print_clades_only.tsv'), 'w') as f:
+            f.write('Clade\tNumber_of_samples\tSamples\n')
+            for species, samples in sorted(species2samples.items(), key=lambda kv: len(kv[1]), reverse=True):
+                info('\t{}: in {} samples.'.format(species, len(samples)))
+                f.write('{}\t{}\t{}\n'.format(species, len(samples),
+                                              ','.join(map(Strainphlan.sample_path_to_name, samples))))
         info('Done.')
 
 