Skip to content

Commit

Permalink
add --merged_profiles option to fix_relab
Browse files Browse the repository at this point in the history
  • Loading branch information
Cengoni committed May 21, 2024
1 parent 3080d94 commit 102adcb
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 33 deletions.
108 changes: 76 additions & 32 deletions metaphlan/utils/fix_relab_mpa4.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
try:
from .util_fun import info, error, warning
except ImportError:
from util_fun import info, error, warning
from util_fun import info, error, warning, openrt
import argparse as ap
import numpy as np

script_install_folder = os.path.dirname(os.path.abspath(__file__))
OCT22_FIXES=os.path.join(script_install_folder,'oct22_fix_tax.tsv')
Expand All @@ -21,10 +22,18 @@ def read_params():
Returns:
namespace: The populated namespace with the command line arguments
"""
p = ap.ArgumentParser(formatter_class=ap.RawTextHelpFormatter, add_help=False)
p = ap.ArgumentParser(formatter_class=ap.RawTextHelpFormatter, add_help=False,
description = "\nThis script allows you to fix some taxonomic inconsistencies "
"present in mpa_vOct22_CHOCOPhlAnSGB_202212 and mpa_vJun23_CHOCOPhlAnSGB_202307.\n" +
"The output profile will have fixed taxonomies and renormalized relative abundances.\n")

requiredNamed = p.add_argument_group('required arguments')
requiredNamed.add_argument('--input', type=str, default=None, help="The path to the input profile")
requiredNamed.add_argument('--output', type=str, default=None, help="The path to the output profile")
requiredNamed.add_argument('-i','--input', type=str, default=None, help="The path to the input profile")
requiredNamed.add_argument('-o','--output', type=str, default=None, help="The path to the output profile")
p.add_argument('--merged_profiles', action='store_true', default=False, help=("To specify when running the script on profiles that were already merged with merge_metaphlan_tables.py"))
p.add_argument("-h", "--help", action="help", help="show this help message and exit")


return p.parse_args()

def read_oct22_fixes(file):
Expand Down Expand Up @@ -54,16 +63,38 @@ def check_params(args):
if not args.output:
error('--output must be specified', exit=True)

def fix_relab_mpa4(input, output):
def assign_higher_taxonomic_levels(taxa_levs, merged):
    """Propagate the values of the deepest level up to every ancestor level.

    ``taxa_levs`` is a list of dicts ordered from the shallowest level
    (index 0) to the deepest one (index -1).  Only the last dict needs to be
    populated on entry; every other level is filled in (or accumulated into)
    by summing the values of its children.  A clade's parent is obtained by
    dropping the last ``|``-separated component from the clade name.

    Args:
        taxa_levs (list[dict]): per-level mapping ``clade name -> values``.
            * not merged: values are ``[id_path, abundance, extra]`` where
              ``id_path`` is a ``|``-separated string (presumably the taxid
              lineage -- confirm against the caller) and ``abundance`` is a
              number.
            * merged: values are a flat sequence of numbers, one per column
              of the merged table.
        merged (bool): whether the entries use the merged (multi-column)
            layout.

    Returns:
        list[dict]: the same ``taxa_levs`` object, updated in place.
    """
    # Walk from the deepest level upwards; the level count is derived from
    # the input rather than hard-coded.
    for i in range(1, len(taxa_levs)):
        children = taxa_levs[-i]
        parents = taxa_levs[-(i + 1)]
        for clade, values in children.items():
            # Strip only the final component of the lineage.
            parent = clade.rsplit('|', 1)[0]
            if merged:
                if parent not in parents:
                    # Store a copy: the caller normalises every entry in
                    # place, so sharing the child's list would rescale the
                    # same numbers once per level that references it.
                    parents[parent] = list(values)
                else:
                    parents[parent] = np.add(parents[parent], values)
            elif parent not in parents:
                # rpartition yields '' when there is no '|', matching the
                # original '|'.join(...split('|')[:-1]) behaviour.
                parent_ids = values[0].rpartition('|')[0]
                parents[parent] = [parent_ids, values[1], '']
            else:
                parents[parent][1] += values[1]
    return taxa_levs

def fix_relab_mpa4(input, output, merged):
taxa_levs = [{},{},{},{},{},{},{},{}]
with open(input, 'r') as rf:
with openrt(input) as rf:
with open(output, 'w') as wf:
for line in rf:
if line.startswith('#mpa_v'):
release = line.strip()[1:]
line = '_'.join(line.split('_')[:-1])
wf.write('{}_202403\n'.format(line.strip()))
elif line.startswith('#') or line.startswith('UNCLASSIFIED'):
elif line.startswith('#') or line.startswith('UNCLASSIFIED') or line.startswith('clade_name'):
wf.write(line)
else:
if 't__' in line:
Expand All @@ -77,30 +108,43 @@ def fix_relab_mpa4(input, output):
elif release == 'mpa_vOct22_CHOCOPhlAnSGB_202212':
line = line.strip().split('\t')
if line[0] in oct_fixes:
line[0],line[1] = oct_fixes[line[0]]

taxa_levs[-1][line[0]] = [line[1], float(line[2]), line[3] if len(line)==4 else '']

for i in range(1,8):
j = i+1
for ss in taxa_levs[-i]:
gg = ss.replace('|{}'.format(ss.split('|')[-1]), '')
gg_n = '|'.join(taxa_levs[-i][ss][0].split('|')[:-1])
if gg not in taxa_levs[-j]:
taxa_levs[-j][gg] = [gg_n, taxa_levs[-i][ss][1], '']
else:
taxa_levs[-j][gg][1] += taxa_levs[-i][ss][1]

sum_level = dict()
for level in range(len(taxa_levs)):
sum_level[level] = 0
for tax in taxa_levs[level]:
sum_level[level] += taxa_levs[level][tax][1]

for level in range(len(taxa_levs)):
for tax in taxa_levs[level]:
taxa_levs[level][tax][1] = round(100 * taxa_levs[level][tax][1]/sum_level[level], 5)
wf.write(tax + '\t' + '\t'.join([str(x) for x in taxa_levs[level][tax]]) + '\n')
if not merged:
line[0],line[1] = oct_fixes[line[0]]
else:
line[0] = oct_fixes[line[0]][0]

if not merged:
taxa_levs[-1][line[0]] = [line[1], float(line[2]), line[3] if len(line)==4 else '']
else:
taxa_levs[-1][line[0]] = [float(l) for l in line[1:]]
ncols = len(line)-1

taxa_levs = assign_higher_taxonomic_levels(taxa_levs, merged)

# normalize the relative abundances and write to file
if not merged:
sum_level = dict()
for level in range(len(taxa_levs)):
sum_level[level] = 0
for tax in taxa_levs[level]:
sum_level[level] += taxa_levs[level][tax][1]

for level in range(len(taxa_levs)):
for tax in taxa_levs[level]:
taxa_levs[level][tax][1] = round(100 * taxa_levs[level][tax][1]/sum_level[level], 5)
wf.write(tax + '\t' + '\t'.join([str(x) for x in taxa_levs[level][tax]]) + '\n')
else:
sum_level = dict()
for level in range(len(taxa_levs)):
sum_level[level] = [0]*ncols
for tax in taxa_levs[level]:
sum_level[level] = np.add(sum_level[level], taxa_levs[level][tax])

for level in range(len(taxa_levs)):
for tax in taxa_levs[level]:
for n in range(len(taxa_levs[level][tax])):
taxa_levs[level][tax][n] = round(100 * taxa_levs[level][tax][n]/sum_level[level][n], 5)
wf.write(tax + '\t' + '\t'.join([str(x) for x in taxa_levs[level][tax]]) + '\n')


def main():
Expand All @@ -110,7 +154,7 @@ def main():
info("Start fixing profile")
check_params(args)
oct_fixes = read_oct22_fixes(OCT22_FIXES)
fix_relab_mpa4(args.input, args.output)
fix_relab_mpa4(args.input, args.output, args.merged_profiles)
exec_time = time.time() - t0
info("Finish fixing profile ({} seconds)".format(round(exec_time, 2)))

Expand Down
2 changes: 1 addition & 1 deletion metaphlan/utils/util_fun.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def error(message, init_new_line=True, exit=False, exit_value=1):
sys.stdout.write('\n')

if exit:
sys.stderr.write('{}: Stop StrainPhlAn execution.\n'.format(
sys.stderr.write('{}: Stop execution.\n'.format(
time.ctime(int(time.time()))))
sys.exit(exit_value)

Expand Down

0 comments on commit 102adcb

Please sign in to comment.