Skip to content

Commit

Permalink
Merge pull request #46 from collaborativebioinformatics/parallelization
Browse files Browse the repository at this point in the history
Parallelization and refactoring
  • Loading branch information
wdecoster authored Nov 3, 2021
2 parents 2b609e9 + 233f216 commit ef7cd87
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 128 deletions.
193 changes: 66 additions & 127 deletions STRdust/STRdust.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,19 @@
import tempfile
from argparse import ArgumentParser

import sys
import os

import shutil
from concurrent.futures import ProcessPoolExecutor
import pysam
import re
import subprocess
import logging
import pandas as pd
from .version import __version__

from spoa import poa
from itertools import groupby


logger = logging.getLogger()


def _validate_path(path):
try:
os.makedirs(path, exist_ok=True)
temp_dir_path = tempfile.mkdtemp(dir=path)
os.rmdir(temp_dir_path)
return True
except OSError:
return False


def _enable_logging(log_file, debug, overwrite):
"""
Turns on logging, sets debug levels and assigns a log file
"""
log_formatter = logging.Formatter("[%(asctime)s] %(name)s: %(levelname)s: "
"%(message)s", "%Y-%m-%d %H:%M:%S")
console_formatter = logging.Formatter("[%(asctime)s] %(levelname)s: "
"%(message)s", "%Y-%m-%d %H:%M:%S")
from itertools import groupby, repeat

console_log = logging.StreamHandler()
console_log.setFormatter(console_formatter)

if not debug:
console_log.setLevel(logging.INFO)

if overwrite:
open(log_file, "w").close()

file_handler = logging.FileHandler(log_file, mode="a")
file_handler.setFormatter(log_formatter)

logger.setLevel(logging.DEBUG)
logger.addHandler(console_log)
logger.addHandler(file_handler)


class InputException(Exception):
pass
import STRdust.utils as utils


class Insertion(object):
Expand Down Expand Up @@ -89,98 +47,56 @@ def is_overlapping(self, other, distance=15):
return all(condition)


def _check_bam_files(bam_file):
"""
Check existance of input files and generate index file if it is absent
:param bam_file: phased bam file with/without bai file
"""

if not os.path.exists(bam_file):
raise InputException(f"Can't open {bam_file}")

samfile = pysam.AlignmentFile(bam_file, "rb")
if not samfile.has_index():
logging.info("Input bam file does not have index file (.bai). Generating now.")
pysam.index(bam_file)


def main():
    """
    Command-line entry point.

    Prepares the output directory and logging, checks the input BAM, runs
    `run` either on the single requested region or on every selected
    reference sequence in parallel, merges the per-region tables into
    ``strdust-list.vcf`` and finally removes the temporary directories
    unless ``--save_temp`` was given.
    """
    args = get_args()
    args.out_dir = utils.create_output_directory(args.out_dir)
    utils._enable_logging(args.out_dir, args.debug, overwrite=True)

    # Check input files. Every later step opens the BAM, so there is no point
    # in continuing (and failing inside a worker process) when it is unusable.
    try:
        utils._check_bam_files(args.bam)
    except utils.InputException as err:
        logging.error(f"Problem with input files: {err}")
        sys.exit(1)

    args.ins_dir, args.vcf_dir = utils.setup_temp_dirs(args.out_dir)

    vcf_final_file = os.path.join(args.out_dir, "strdust-list.vcf")
    if args.region:
        temporary_files = [run(args, args.region)]
    else:
        # Reference names containing an underscore are skipped
        # (presumably alt/unplaced contigs -- confirm against the reference used).
        with pysam.AlignmentFile(args.bam, "rb") as bam:
            chromosomes = [c for c in bam.references if '_' not in c]
        # One task per reference sequence; `args` is repeated so map() can pair
        # it with each chromosome name.
        with ProcessPoolExecutor(max_workers=args.threads) as executor:
            temporary_files = list(executor.map(run, repeat(args), chromosomes))

    concatenate_output(temporary_files, vcf_final_file)

    if not args.save_temp:
        logging.info("Cleaning up output directory.")
        shutil.rmtree(args.ins_dir)
        shutil.rmtree(args.vcf_dir)
    logging.info("Enjoy your annotation.")
def run(args, region):
    """
    Process one region: extract insertions, run mreps on them, write a table.

    :param args: parsed arguments; reads ``bam``, ``distance``, ``ins_dir``,
                 ``vcf_dir``, ``mreps_res`` and ``save_temp``
    :param region: region string handed to the BAM fetch (chromosome name or interval)
    :return: path of the per-region table. The file is only written when mreps
             reported at least one repeat, so callers must check that it exists.
    """
    logging.info(f"-- Start processing: {region} --")
    insertions = extract_insertions(args.bam, region, minlen=15,
                                    mapq=10, merge_distance=args.distance, flank_distance=50)
    insertions = merge_overlapping_insertions(sorted(insertions), merge_distance=args.distance)

    # ':' and '-' from interval notation are replaced to get a plain file name
    region_string = region.replace(':', '_').replace('-', '_')
    ins_chr_file = os.path.join(args.ins_dir, f"ins_{region_string}.fa")
    write_ins_file(insertions, ins_chr_file)

    # mreps reads the FASTA file, so it may only be deleted after mreps has run
    mreps_dict = parse_mreps_result(run_mreps(ins_chr_file, args.mreps_res))
    if not args.save_temp:
        os.remove(ins_chr_file)

    vcf_temporary_file = os.path.join(args.vcf_dir, f"strdust-{region_string}.tsv")
    if mreps_dict:
        vcfy(mreps_dict, vcf_temporary_file)
    return vcf_temporary_file


def extract_insertions(bamf, chrom, minlen, mapq, merge_distance, flank_distance):
def extract_insertions(bamf, region, minlen, mapq, merge_distance, flank_distance):
"""
Extract insertions and softclips from a bam file based on parsing CIGAR strings
Expand All @@ -200,11 +116,11 @@ def extract_insertions(bamf, chrom, minlen, mapq, merge_distance, flank_distance
BAM_CBACK 9
"""

logging.info("Start extraction of insertions and softclips")
logging.info(f"{region}: Start extraction of insertions and softclips")

insertions = []
bam = pysam.AlignmentFile(bamf, "rb")
for read in bam.fetch(region=chrom, multiple_iterators=True):
for read in bam.fetch(region=region, multiple_iterators=True):
insertions_per_read = []
read_position = 0
reference_position = read.reference_start + 1
Expand All @@ -230,7 +146,7 @@ def extract_insertions(bamf, chrom, minlen, mapq, merge_distance, flank_distance
insertions_per_read = horizontal_merge(insertions_per_read, merge_distance)
insertions.extend(insertions_per_read)

logging.info("End with extraction of insertions and softclips.")
logging.info(f"{region}: End with extraction of insertions and softclips.")
return insertions


Expand All @@ -240,7 +156,6 @@ def get_haplotype(read):
return str(read.get_tag('HP')) if read.has_tag('HP') else 'un'


# PLEASE REVIEW FUNCTION BELOW
def horizontal_merge(insertions, merge_distance):
"""Merge insertions occuring in the same read if they are within merge_distance"""
insertions.sort()
Expand Down Expand Up @@ -379,7 +294,7 @@ def parse_mreps_result(mreps_output_str):
info_list = info.split("\t")
loc_list = re.findall(r'\d+', info_list[0])
seq_list = info_list[-1].split()
seq = max(seq_list, key = seq_list.count)
seq = max(seq_list, key=seq_list.count)
# seq = info_list[-1].split()[0]
ins_info = loc_list + [seq]
temp.append(ins_info)
Expand All @@ -398,10 +313,13 @@ def vcfy(mrep_dict, oufvcf):
"""
strdust_vcf = open(oufvcf, "w")
logging.info("Writing results to %s" % oufvcf)
strdust_vcf.write("#chrom\tstart\tend\trepeat_seq\tsize\n")
strdust_vcf.write("chrom\tstart\tend\trepeat_seq\tsize\n")

for dustspec in mrep_dict.keys():
[chrom, start_ins, end_ins] = dustspec.split("'")[1].split("_")
try:
[chrom, start_ins, end_ins] = dustspec.split("'")[1].split("_")
except ValueError:
sys.exit(dustspec)
start_ins = int(start_ins)
end_ins = int(end_ins)
# mreps can find more than on repeated seq
Expand All @@ -426,6 +344,19 @@ def vcfy(mrep_dict, oufvcf):
strdust_vcf.close()


def concatenate_output(temporary_files, output_file):
    """
    Concatenate the per-region tables and sort them by chromosome and start position.

    The run function returns just the file name, and the file may not have been
    created when no variants were called. Therefore each path is checked for
    existence before reading.

    :param temporary_files: iterable of paths to tab-separated per-region tables
    :param output_file: path of the merged, tab-separated table to write
    """
    tables = [pd.read_csv(f, sep="\t") for f in temporary_files if os.path.isfile(f)]
    if tables:
        merged = pd.concat(tables, ignore_index=True).sort_values(by=['chrom', 'start'])
    else:
        # pd.concat raises ValueError on an empty list; when no region produced
        # a table, still write a header-only file with the columns vcfy emits.
        merged = pd.DataFrame(columns=['chrom', 'start', 'end', 'repeat_seq', 'size'])
    merged.to_csv(output_file, sep="\t", index=False)


def get_args():
parser = ArgumentParser("Genotype STRs from long reads")
parser.add_argument("bam", help="phased bam file")
Expand All @@ -439,13 +370,21 @@ def get_args():
help="tolerent error rate in mreps repeat finding",
type=int,
default=1)
parser.add_argument("-t", "--threads",
help="number of threads to use",
type=int,
default=8)
parser.add_argument("--save_temp", action="store_true",
dest="save_temp", default=False,
help="enable saving temporary files in output directory")
parser.add_argument("--debug", action="store_true",
dest="debug", default=False,
help="enable debug output")
parser.add_argument("--region", help="run on a specific interval only")
parser.add_argument("-v", "--version",
help="Print version and exit.",
action="version",
version=f'STRdust {__version__}')

return parser.parse_args()

Expand Down
83 changes: 83 additions & 0 deletions STRdust/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import os
import sys
import logging
import tempfile
import pysam


class InputException(Exception):
    """Signals a problem with the input files, e.g. a BAM path that does not exist."""


def setup_temp_dirs(out_dir):
    """
    Create the two temporary working directories inside ``out_dir``.

    :param out_dir: output directory that will hold the temporary directories
    :return: tuple ``(ins_dir, vcf_dir)`` with the paths ``chrs_ins_tmp`` and
             ``chrs_vcf_tmp`` below ``out_dir``
    """
    ins_dir = os.path.join(out_dir, "chrs_ins_tmp")
    vcf_dir = os.path.join(out_dir, "chrs_vcf_tmp")
    for directory in (ins_dir, vcf_dir):
        # exist_ok avoids the check-then-create race and makes repeated runs
        # into the same output directory safe; missing parents are created too.
        os.makedirs(directory, exist_ok=True)
    return ins_dir, vcf_dir


def create_output_directory(out_dir):
    """
    Make sure the output directory exists and is writable.

    A missing directory is created (including missing parents). An existing
    directory is probed for write access, and the program exits with a message
    when the probe fails.

    :param out_dir: requested output directory
    :return: absolute path of the output directory
    """
    if os.path.isdir(out_dir):
        # Probe out_dir itself: _validate_path creates the directory it is
        # given and does not remove it, so probing a "test" subdirectory would
        # leave that subdirectory behind in the user's output.
        if not _validate_path(out_dir):
            sys.exit(f"Problem with writing permissions in output directory. {out_dir}\n")
    else:
        os.makedirs(out_dir, exist_ok=True)
    return os.path.abspath(out_dir)


def _validate_path(path):
try:
os.makedirs(path, exist_ok=True)
temp_dir_path = tempfile.mkdtemp(dir=path)
os.rmdir(temp_dir_path)
return True
except OSError:
return False


def _enable_logging(out_dir, debug, overwrite):
"""
Turns on logging, sets debug levels and assigns a log file
"""
logger = logging.getLogger()
log_file = os.path.join(out_dir, "STRdust.log")
log_formatter = logging.Formatter("[%(asctime)s] %(name)s: %(levelname)s: "
"%(message)s", "%Y-%m-%d %H:%M:%S")
console_formatter = logging.Formatter("[%(asctime)s] %(levelname)s: "
"%(message)s", "%Y-%m-%d %H:%M:%S")

console_log = logging.StreamHandler()
console_log.setFormatter(console_formatter)

if not debug:
console_log.setLevel(logging.INFO)

if overwrite:
open(log_file, "w").close()

file_handler = logging.FileHandler(log_file, mode="a")
file_handler.setFormatter(log_formatter)

logger.setLevel(logging.DEBUG)
logger.addHandler(console_log)
logger.addHandler(file_handler)


def _check_bam_files(bam_file):
    """
    Check existence of the input file and generate an index file if it is absent.

    :param bam_file: phased bam file with/without bai file
    :raises InputException: when ``bam_file`` does not exist
    """

    if not os.path.exists(bam_file):
        raise InputException(f"Can't open {bam_file}")

    # Close the handle again once the index status is known, so the file is
    # not left open while it is being indexed.
    with pysam.AlignmentFile(bam_file, "rb") as samfile:
        indexed = samfile.has_index()

    if not indexed:
        logging.info("Input bam file does not have index file (.bai). Generating now.")
        pysam.index(bam_file)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
keywords='STR long reads',
packages=find_packages(),
python_requires='>=3',
install_requires=['pysam', 'pyspoa'],
install_requires=['pysam', 'pyspoa', 'pandas'],
package_data={'STRdust': []},
package_dir={'STRdust': 'STRdust'},
include_package_data=True,
Expand Down

0 comments on commit ef7cd87

Please sign in to comment.