diff --git a/eskrim/eskrim.py b/eskrim/eskrim.py
index 7b20175..4cb5d32 100644
--- a/eskrim/eskrim.py
+++ b/eskrim/eskrim.py
@@ -9,14 +9,15 @@
 import sys
 import re
 from collections import namedtuple
-import fileinput
+from fileinput import FileInput
 import itertools
 import random
 import subprocess
 import multiprocessing
 import multiprocessing.pool
-from tempfile import NamedTemporaryFile
 from importlib.metadata import version
+from tempfile import NamedTemporaryFile
+from typing import cast, IO, TextIO, Generator
 try:
     import dna_jellyfish
 except ImportError as import_err:
@@ -24,13 +25,13 @@
 eskrim_version = "ESKRIM v" + version("eskrim")
 
 
-def setup_logger():
+def setup_logger() -> None:
     logger = logging.getLogger()
     logger.setLevel(logging.INFO)
     logging.basicConfig(format='%(asctime)s :: %(levelname)s :: %(message)s')
 
 
-def check_program_available(program):
+def check_program_available(program) -> None:
     try:
         subprocess.call([program], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
     except OSError:
@@ -38,7 +39,7 @@ def check_program_available(program):
         sys.exit(1)
 
 
-def num_threads_type(value):
+def num_threads_type(value) -> int:
     max_num_threads = multiprocessing.cpu_count()
 
     try:
@@ -67,7 +68,7 @@ def readable_writable_dir(path):
     return path
 
 
-def get_parameters():
+def get_parameters() -> argparse.Namespace:
     """Parse command line parameters.
     """
     parser = argparse.ArgumentParser(description=__doc__,
@@ -120,19 +121,19 @@ def get_parameters():
     return parser.parse_args()
 
 
-def hook_compressed_text(filename, mode, encoding='utf8'):
+def hook_compressed_text(filename, mode, encoding="utf8") -> TextIO:
     ext = os.path.splitext(filename)[1]
 
     if ext == '.gz':
         import gzip
-        return gzip.open(filename, mode + 't', encoding=encoding)
+        return cast(TextIO, gzip.open(filename, mode + 't', encoding=encoding))
     if ext == '.bz2':
         import bz2
-        return bz2.open(filename, mode + 't', encoding=encoding)
+        return cast(TextIO, bz2.open(filename, mode + 't', encoding=encoding))
-    return open(filename, mode, encoding=encoding)
+    return cast(TextIO, open(filename, mode, encoding=encoding))
 
 
-def check_fastq_files(fastq_files):
+def check_fastq_files(fastq_files: list[str]) -> None:
     fastq_extensions = {
         '.fastq', '.fq', '.fastq.gz', '.fq.gz', '.fastq.bz2', '.fq.bz2'}
     fastq_extensions_escape = (
@@ -143,18 +144,17 @@ def check_fastq_files(fastq_files):
     problematic_fastq_files = [os.path.basename(fastq_file) for fastq_file in fastq_files
                                if re.sub(regexp_match_fastq_extensions, '', fastq_file).endswith(
                                    ('_2', '.2', '_R2'))]
-    problematic_fastq_files = ','.join(problematic_fastq_files)
 
     if problematic_fastq_files:
         logging.warning('Input FASTQ files probably contain reverse reads (%s)',
-                        problematic_fastq_files)
+                        ','.join(problematic_fastq_files))
         logging.warning('Use only forward reads for accurate results\n')
 
 
 FastqEntry = namedtuple('FastqEntry', ['name', 'seq', 'qual'])
 
 
-def fastq_reader(fastq_fi, target_read_length):
+def fastq_reader(fastq_fi: FileInput[str], target_read_length: int) -> Generator[FastqEntry, None, None]:
     while fastq_fi:
         try:
             name = next(fastq_fi).rstrip('\n')
@@ -186,12 +186,12 @@ def fastq_reader(fastq_fi, target_read_length):
 fastq_reader.num_too_short_reads_ignored = 0
 
 
-def fastq_formatter(fastq_entry):
+def fastq_formatter(fastq_entry: FastqEntry) -> str:
     return f'{fastq_entry.name}\n{fastq_entry.seq}\n+\n{fastq_entry.qual}\n'
 
 
-def subsample_fastq_files(input_fastq_files, target_num_reads, target_read_length):
-    with fileinput.FileInput(input_fastq_files, openhook=hook_compressed_text) as fastq_fi:
+def subsample_fastq_files(input_fastq_files: list[str], target_num_reads: int, target_read_length: int) -> list[FastqEntry]:
+    with FileInput(input_fastq_files, openhook=hook_compressed_text) as fastq_fi:
         # Fill reservoir
         selected_reads = list(
             itertools.islice(fastq_reader(fastq_fi, target_read_length), target_num_reads))
@@ -221,7 +221,7 @@ def subsample_fastq_files(input_fastq_files, target_num_reads, target_read_lengt
     return selected_reads
 
 
-def create_jf_db(reads, kmer_length, num_threads, temp_dir):
+def create_jf_db(reads: list[FastqEntry], kmer_length: int, num_threads: int, temp_dir: str) -> IO[bytes]:
     jellyfish_db_file = NamedTemporaryFile(
         dir=temp_dir, prefix="eskrim_", suffix=".jf", delete=True, delete_on_close=False
     )
@@ -233,6 +233,7 @@ def create_jf_db(reads, kmer_length, num_threads, temp_dir):
         '/dev/stdin']
 
     with subprocess.Popen(jellyfish_count_cmd, stdin=subprocess.PIPE) as jellyfish_count_proc:
+        assert jellyfish_count_proc.stdin is not None
         for read in reads:
             jellyfish_count_proc.stdin.write(fastq_formatter(read).encode())
 
@@ -244,36 +245,38 @@ def create_jf_db(reads, kmer_length, num_threads, temp_dir):
     return jellyfish_db_file
 
 
-def count_distinct_kmers(jellyfish_db_filename):
+def count_distinct_kmers(jellyfish_db_filename: str) -> int:
     jellyfish_stats_cmd = ['jellyfish', 'stats', jellyfish_db_filename]
 
     with subprocess.Popen(jellyfish_stats_cmd, stdout=subprocess.PIPE) as jellyfish_stats_proc:
+        assert jellyfish_stats_proc.stdout is not None
         jellyfish_stats_proc.stdout.readline()
-        num_distinct_kmers = jellyfish_stats_proc.stdout.readline()
+        num_distinct_kmers_str = jellyfish_stats_proc.stdout.readline()
         jellyfish_stats_proc.wait()
 
-    num_distinct_kmers = int(num_distinct_kmers.split()[-1])
+    num_distinct_kmers = int(num_distinct_kmers_str.split()[-1])
 
     return num_distinct_kmers
 
 
-def count_solid_kmers(jellyfish_db_filename):
+def count_solid_kmers(jellyfish_db_filename: str) -> int:
     jellyfish_stats_cmd = ['jellyfish', 'stats', '-L', '2', jellyfish_db_filename]
 
     with subprocess.Popen(jellyfish_stats_cmd, stdout=subprocess.PIPE) as jellyfish_stats_proc:
+        assert jellyfish_stats_proc.stdout is not None
         jellyfish_stats_proc.stdout.readline()
-        num_solid_kmers = jellyfish_stats_proc.stdout.readline()
+        num_solid_kmers_str = jellyfish_stats_proc.stdout.readline()
         jellyfish_stats_proc.wait()
 
-    num_solid_kmers = int(num_solid_kmers.split()[-1])
+    num_solid_kmers = int(num_solid_kmers_str.split()[-1])
 
     return num_solid_kmers
 
 
-def count_mercy_kmers_aux(params):
+def count_mercy_kmers_aux(params: tuple) -> int:
     reads, jellyfish_db_filename, read_length, kmer_length = params
     num_mercy_kmers = 0
     jellyfish_db = dna_jellyfish.QueryMerFile(jellyfish_db_filename)
@@ -295,7 +298,7 @@ def count_mercy_kmers_aux(params):
     return num_mercy_kmers
 
 
-def count_mercy_kmers(reads, jellyfish_db_filename, read_length, kmer_length, num_threads):
+def count_mercy_kmers(reads: list[FastqEntry], jellyfish_db_filename: str, read_length: int, kmer_length: int, num_threads: int) -> int:
     chunk_size = 200000
     chunks_parameters = ((reads[x:x + chunk_size], jellyfish_db_filename, read_length, kmer_length)
                          for x in range(0, len(reads), chunk_size))
@@ -307,7 +310,7 @@ def count_mercy_kmers(reads, jellyfish_db_filename, read_length, kmer_length, nu
     return final_num_mercy_kmers
 
 
-def main():
+def main() -> None:
     setup_logger()
 
     parameters = get_parameters()
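
Reviewer note: the sketch below is not part of the patch; it only illustrates, under the same typing assumptions the patch makes, why the added cast(TextIO, ...) calls and assert ... is not None guards are needed to satisfy a static checker such as mypy. The names open_text and first_stats_line are placeholders, not ESKRIM functions.

# Standalone sketch (assumes Python 3.9+ and mypy-style checking); not part of eskrim.py.
import gzip
import os
import subprocess
from typing import TextIO, cast


def open_text(filename: str, mode: str = "r", encoding: str = "utf8") -> TextIO:
    # gzip.open() is typed as returning a union of binary and text streams, so an
    # explicit cast is the usual way to keep a "-> TextIO" annotation honest.
    if os.path.splitext(filename)[1] == ".gz":
        return cast(TextIO, gzip.open(filename, mode + "t", encoding=encoding))
    return cast(TextIO, open(filename, mode, encoding=encoding))


def first_stats_line(cmd: list[str]) -> str:
    # Popen.stdout is Optional; it is only set because stdout=PIPE is passed.
    # The assert documents that invariant and narrows the type for the checker.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE) as proc:
        assert proc.stdout is not None
        line = proc.stdout.readline().decode()
        proc.wait()
    return line


if __name__ == "__main__":
    print(open_text(__file__).readline(), end="")
    print(first_stats_line(["echo", "hello"]), end="")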