Commit

Add typehints
fplazaonate committed Jul 23, 2024
1 parent d18ebbc commit 58facc1
Showing 1 changed file with 30 additions and 27 deletions.
57 changes: 30 additions & 27 deletions eskrim/eskrim.py
@@ -9,36 +9,37 @@
import sys
import re
from collections import namedtuple
import fileinput
from fileinput import FileInput
import itertools
import random
import subprocess
import multiprocessing
import multiprocessing.pool
from tempfile import NamedTemporaryFile
from importlib.metadata import version
from tempfile import NamedTemporaryFile
from typing import cast, IO, TextIO, Generator
try:
import dna_jellyfish
except ImportError as import_err:
raise RuntimeError("Python bindings of jellyfish are not installed") from import_err

eskrim_version = "ESKRIM v" + version("eskrim")

def setup_logger():
def setup_logger() -> None:
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig(format='%(asctime)s :: %(levelname)s :: %(message)s')


def check_program_available(program):
def check_program_available(program) -> None:
try:
subprocess.call([program], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
except OSError:
logging.error('%s not found or not in system path', program)
sys.exit(1)

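check_program_available probes the system path by simply trying to launch the program and treating OSError as "not installed". The rest of the script shells out to the jellyfish CLI, so the collapsed main() presumably guards on it; a minimal usage sketch (the call site is an assumption, not visible in this diff):

# Hypothetical call site: abort early if the external jellyfish binary is missing.
check_program_available('jellyfish')
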

def num_threads_type(value):
def num_threads_type(value) -> int:
max_num_threads = multiprocessing.cpu_count()

try:
@@ -67,7 +68,7 @@ def readable_writable_dir(path):
return path


def get_parameters():
def get_parameters() -> argparse.Namespace:
"""Parse command line parameters.
"""
parser = argparse.ArgumentParser(description=__doc__,
@@ -120,19 +121,19 @@ def get_parameters():
return parser.parse_args()

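num_threads_type (above) and readable_writable_dir (named in the hunk header) follow the argparse pattern of type= callables: they take the raw command-line string, validate it, and return the converted value. The actual add_argument calls are collapsed in this diff, so the option names and defaults below are purely illustrative:

# Hypothetical wiring inside get_parameters(); the real flags and defaults are
# collapsed in this diff and may differ.
parser.add_argument('--num-threads', type=num_threads_type, default=1,
                    help='number of worker threads, validated against cpu_count()')
parser.add_argument('--temp-dir', type=readable_writable_dir, default='.',
                    help='directory used for temporary jellyfish databases')
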

def hook_compressed_text(filename, mode, encoding='utf8'):
def hook_compressed_text(filename, mode, encoding = "utf8") -> TextIO:
ext = os.path.splitext(filename)[1]
if ext == '.gz':
import gzip
return gzip.open(filename, mode + 't', encoding=encoding)
return cast(TextIO, gzip.open(filename, mode + 't', encoding=encoding))
if ext == '.bz2':
import bz2
return bz2.open(filename, mode + 't', encoding=encoding)
return cast(TextIO, bz2.open(filename, mode + 't', encoding=encoding))

return open(filename, mode, encoding=encoding)
return cast(TextIO, open(filename, mode, encoding=encoding))

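The cast(TextIO, ...) wrappers are the crux of this hunk: gzip.open and bz2.open are annotated with overloads keyed on literal mode strings, so with a computed mode like mode + 't' the type checker cannot narrow the result to a text stream, even though that is what comes back at runtime. cast changes nothing at runtime; it only records that fact for the checker. The function is written as an openhook for FileInput, so each input file is transparently decompressed based on its extension; a small usage sketch (file names are hypothetical):

# Hypothetical usage as a FileInput openhook: files are opened in text mode,
# with gzip/bz2 decompression chosen per file from its extension.
with FileInput(['sample_1.fastq.gz', 'sample_2.fq'], openhook=hook_compressed_text) as fastq_fi:
    first_header = next(fastq_fi).rstrip('\n')
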

def check_fastq_files(fastq_files):
def check_fastq_files(fastq_files: list[str]) -> None:
fastq_extensions = {
'.fastq', '.fq', '.fastq.gz', '.fq.gz', '.fastq.bz2', '.fq.bz2'}
fastq_extensions_escape = (
@@ -143,18 +144,17 @@ def check_fastq_files(fastq_files):
problematic_fastq_files = [os.path.basename(fastq_file) for fastq_file in fastq_files
if re.sub(regexp_match_fastq_extensions, '', fastq_file).endswith(
('_2', '.2', '_R2'))]
problematic_fastq_files = ','.join(problematic_fastq_files)

if problematic_fastq_files:
logging.warning('Input FASTQ files probably contain reverse reads (%s)',
problematic_fastq_files)
','.join(problematic_fastq_files))
logging.warning('Use only forward reads for accurate results\n')

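The construction of the extension regexp is collapsed in this diff; the visible pieces (a set of extensions, an escaped form, and a re.sub that strips the extension before checking for a _2/.2/_R2 suffix) suggest something along these lines. A hedged reconstruction, not the actual collapsed code:

# Hedged reconstruction of the collapsed lines: escape the dots and anchor the
# alternation at the end of the file name.
fastq_extensions_escape = (re.escape(ext) for ext in fastq_extensions)
regexp_match_fastq_extensions = '({})$'.format('|'.join(fastq_extensions_escape))

# Stripping the extension exposes a reverse-read marker, e.g.:
re.sub(regexp_match_fastq_extensions, '', 'sample_R2.fastq.gz')  # -> 'sample_R2'
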

FastqEntry = namedtuple('FastqEntry', ['name', 'seq', 'qual'])


def fastq_reader(fastq_fi, target_read_length):
def fastq_reader(fastq_fi: FileInput[str], target_read_length: int) -> Generator[FastqEntry, None, None]:
while fastq_fi:
try:
name = next(fastq_fi).rstrip('\n')
@@ -186,12 +186,12 @@ def fastq_reader(fastq_fi, target_read_length):
fastq_reader.num_too_short_reads_ignored = 0

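The read counters (this hunk shows num_too_short_reads_ignored) are stored as attributes on the fastq_reader function object itself, so they survive after the generator is exhausted and can be inspected by any caller without threading extra state around. The new return annotation Generator[FastqEntry, None, None] says the generator only yields FastqEntry values and never sends or returns anything; Iterator[FastqEntry] would be an equivalent, shorter spelling. A hypothetical post-run check, not part of this diff:

# Hypothetical reporting after the reads have been consumed; the counter is
# read straight off the function object.
if fastq_reader.num_too_short_reads_ignored:
    logging.warning('%d reads were ignored as too short',
                    fastq_reader.num_too_short_reads_ignored)
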

def fastq_formatter(fastq_entry):
def fastq_formatter(fastq_entry: FastqEntry) -> str:
return f'{fastq_entry.name}\n{fastq_entry.seq}\n+\n{fastq_entry.qual}\n'

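fastq_formatter is the inverse of the reader: it serializes a FastqEntry back into a four-line FASTQ record, always writing a bare '+' separator. A tiny round trip with made-up values:

# Illustrative record (values are made up); the name keeps its leading '@',
# and the '+' separator line is re-added on output.
entry = FastqEntry(name='@read1', seq='ACGT', qual='IIII')
assert fastq_formatter(entry) == '@read1\nACGT\n+\nIIII\n'
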

def subsample_fastq_files(input_fastq_files, target_num_reads, target_read_length):
with fileinput.FileInput(input_fastq_files, openhook=hook_compressed_text) as fastq_fi:
def subsample_fastq_files(input_fastq_files: list[str], target_num_reads: int, target_read_length: int) -> list[FastqEntry]:
with FileInput(input_fastq_files, openhook=hook_compressed_text) as fastq_fi:
# Fill reservoir
selected_reads = list(
itertools.islice(fastq_reader(fastq_fi, target_read_length), target_num_reads))
@@ -221,7 +221,7 @@ def subsample_fastq_files(input_fastq_files, target_num_reads, target_read_lengt
return selected_reads

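subsample_fastq_files starts by filling a reservoir with the first target_num_reads reads; the rest of the loop is collapsed in this diff. The standard continuation of reservoir sampling (Algorithm R) replaces reservoir slots with decreasing probability, roughly as in the sketch below; the actual collapsed code may differ in details such as counting or seeding:

# Textbook continuation of reservoir sampling over the remaining reads;
# a sketch of the technique, not the collapsed code itself.
for num_reads_seen, read in enumerate(fastq_reader(fastq_fi, target_read_length),
                                      start=target_num_reads + 1):
    slot = random.randrange(num_reads_seen)
    if slot < target_num_reads:
        selected_reads[slot] = read
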

def create_jf_db(reads, kmer_length, num_threads, temp_dir):
def create_jf_db(reads: list[FastqEntry], kmer_length: int, num_threads: int, temp_dir: str) -> IO[bytes]:
jellyfish_db_file = NamedTemporaryFile(
dir=temp_dir, prefix="eskrim_", suffix=".jf", delete=True, delete_on_close=False
)
@@ -233,6 +233,7 @@ def create_jf_db(reads, kmer_length, num_threads, temp_dir):
'/dev/stdin']

with subprocess.Popen(jellyfish_count_cmd, stdin=subprocess.PIPE) as jellyfish_count_proc:
assert(jellyfish_count_proc.stdin is not None)
for read in reads:
jellyfish_count_proc.stdin.write(fastq_formatter(read).encode())

@@ -244,36 +245,38 @@ def create_jf_db(reads, kmer_length, num_threads, temp_dir):
return jellyfish_db_file

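Two details in create_jf_db are worth noting. NamedTemporaryFile(..., delete=True, delete_on_close=False) relies on the delete_on_close parameter added in Python 3.12: the .jf file stays on disk after Python closes its handle, so jellyfish and the later stats calls can reopen it by name, and it is still removed when the wrapper is eventually cleaned up (context-manager exit, garbage collection, or interpreter shutdown). The assert on jellyfish_count_proc.stdin exists purely to narrow Optional[IO[bytes]] for the type checker; stdin is only non-None because the process was opened with stdin=subprocess.PIPE. A hedged sketch of how the returned handle is presumably consumed downstream (parameter values are made up):

# Hypothetical downstream use: the stats helpers take a filename, so callers
# presumably pass the temporary file's .name attribute.
jf_db_file = create_jf_db(selected_reads, kmer_length=31, num_threads=4, temp_dir='/tmp')
num_distinct_kmers = count_distinct_kmers(jf_db_file.name)
num_solid_kmers = count_solid_kmers(jf_db_file.name)
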

def count_distinct_kmers(jellyfish_db_filename):
def count_distinct_kmers(jellyfish_db_filename: str) -> int:
jellyfish_stats_cmd = ['jellyfish', 'stats',
jellyfish_db_filename]

with subprocess.Popen(jellyfish_stats_cmd, stdout=subprocess.PIPE) as jellyfish_stats_proc:
assert(jellyfish_stats_proc.stdout is not None)
jellyfish_stats_proc.stdout.readline()
num_distinct_kmers = jellyfish_stats_proc.stdout.readline()
num_distinct_kmers_str = jellyfish_stats_proc.stdout.readline()
jellyfish_stats_proc.wait()

num_distinct_kmers = int(num_distinct_kmers.split()[-1])
num_distinct_kmers = int(num_distinct_kmers_str.split()[-1])

return num_distinct_kmers

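jellyfish stats prints one label/value pair per line (typically Unique, Distinct, Total, Max_count), so the code skips the first line and parses the trailing field of the second, the distinct-k-mer count. Note that the pipe is opened in binary mode, so readline() returns bytes and the renamed num_distinct_kmers_str variable actually holds bytes rather than str; that still works because int() accepts ASCII digits from bytes:

# The Popen pipe is binary, so the 'Distinct' line arrives as bytes; int()
# parses the ASCII digits directly, no decode() needed.
assert int(b'Distinct:  13782344'.split()[-1]) == 13782344
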

def count_solid_kmers(jellyfish_db_filename):
def count_solid_kmers(jellyfish_db_filename: str) -> int:
jellyfish_stats_cmd = ['jellyfish', 'stats',
'-L', '2',
jellyfish_db_filename]

with subprocess.Popen(jellyfish_stats_cmd, stdout=subprocess.PIPE) as jellyfish_stats_proc:
assert(jellyfish_stats_proc.stdout is not None)
jellyfish_stats_proc.stdout.readline()
num_solid_kmers = jellyfish_stats_proc.stdout.readline()
num_solid_kmers_str = jellyfish_stats_proc.stdout.readline()
jellyfish_stats_proc.wait()

num_solid_kmers = int(num_solid_kmers.split()[-1])
num_solid_kmers = int(num_solid_kmers_str.split()[-1])

return num_solid_kmers


def count_mercy_kmers_aux(params):
def count_mercy_kmers_aux(params: tuple) -> int:
reads, jellyfish_db_filename, read_length, kmer_length = params
num_mercy_kmers = 0
jellyfish_db = dna_jellyfish.QueryMerFile(jellyfish_db_filename)
@@ -295,7 +298,7 @@ def count_mercy_kmers_aux(params):
return num_mercy_kmers

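The query loop of count_mercy_kmers_aux is collapsed in this diff; only the QueryMerFile handle is visible. With the dna_jellyfish bindings, lookups are typically done by building a MerDNA, canonicalizing it, and indexing the query file. The sketch below applies that query pattern to every k-mer of every read in the chunk; the actual mercy-k-mer criterion used by ESKRIM is not visible here and may differ:

# Hedged sketch of the collapsed loop: query each k-mer of each read against
# the .jf database and count those seen exactly once. This only illustrates
# the dna_jellyfish query pattern, not the real criterion.
for read in reads:
    for pos in range(read_length - kmer_length + 1):
        mer = dna_jellyfish.MerDNA(read.seq[pos:pos + kmer_length])
        mer.canonicalize()  # match a canonical (-C) jellyfish database, if one was built
        if jellyfish_db[mer] == 1:
            num_mercy_kmers += 1
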

def count_mercy_kmers(reads, jellyfish_db_filename, read_length, kmer_length, num_threads):
def count_mercy_kmers(reads: list[FastqEntry], jellyfish_db_filename: str, read_length: int, kmer_length: int, num_threads: int) -> int:
chunk_size = 200000
chunks_parameters = ((reads[x:x + chunk_size], jellyfish_db_filename, read_length, kmer_length)
for x in range(0, len(reads), chunk_size))
@@ -307,7 +310,7 @@ def count_mercy_kmers(reads, jellyfish_db_filename, read_length, kmer_length, nu
return final_num_mercy_kmers

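The pool dispatch in count_mercy_kmers is collapsed; given the multiprocessing.pool import and the per-chunk parameter generator above, the natural shape is a worker pool that maps count_mercy_kmers_aux over the chunks and sums the partial counts. A hedged sketch, not the collapsed code itself:

# Hedged sketch of the collapsed dispatch: one task per 200000-read chunk,
# summed as results come back in any order.
with multiprocessing.pool.Pool(processes=num_threads) as pool:
    final_num_mercy_kmers = sum(pool.imap_unordered(count_mercy_kmers_aux, chunks_parameters))
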

def main():
def main() -> None:
setup_logger()

parameters = get_parameters()
