From 8559c649ed0653a8c2db4c1aca783f3b6383bddb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20Plaza=20O=C3=B1ate?= Date: Wed, 24 Jul 2024 14:19:26 +0200 Subject: [PATCH] Use builtin openhook --- eskrim/eskrim.py | 32 +++++++++----------------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/eskrim/eskrim.py b/eskrim/eskrim.py index 6b232e0..e9b9f38 100644 --- a/eskrim/eskrim.py +++ b/eskrim/eskrim.py @@ -9,7 +9,7 @@ import sys import re from collections import namedtuple -from fileinput import FileInput +from fileinput import FileInput, hook_compressed import itertools import random import subprocess @@ -33,7 +33,7 @@ def setup_logger() -> None: logging.basicConfig(format="%(asctime)s :: %(levelname)s :: %(message)s") -def check_program_available(program) -> None: +def check_program_available(program: str) -> None: try: subprocess.call([program], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) except OSError: @@ -41,23 +41,23 @@ def check_program_available(program) -> None: sys.exit(1) -def num_threads_type(value) -> int: +def num_threads_type(value: str) -> int: max_num_threads = multiprocessing.cpu_count() try: - value = int(value) + num_threads = int(value) except ValueError as value_err: raise argparse.ArgumentTypeError("NUM_THREADS is not an integer") from value_err - if value <= 0: + if num_threads <= 0: raise argparse.ArgumentTypeError("minimum NUM_THREADS is 1") - if value > max_num_threads: + if num_threads > max_num_threads: raise argparse.ArgumentTypeError(f"maximum NUM_THREADS is {max_num_threads}") - return value + return num_threads -def readable_writable_dir(path): +def readable_writable_dir(path: str): if not os.path.isdir(path): raise NotADirectoryError(path) @@ -158,20 +158,6 @@ def get_parameters() -> argparse.Namespace: return parser.parse_args() -def hook_compressed_text(filename, mode, encoding="utf8") -> TextIO: - ext = os.path.splitext(filename)[1] - if ext == ".gz": - import gzip - - return cast(TextIO, gzip.open(filename, mode + "t", encoding=encoding)) - if ext == ".bz2": - import bz2 - - return cast(TextIO, bz2.open(filename, mode + "t", encoding=encoding)) - - return cast(TextIO, open(filename, mode, encoding=encoding)) - - def check_fastq_files(fastq_files: list[str]) -> None: fastq_extensions = {".fastq", ".fq", ".fastq.gz", ".fq.gz", ".fastq.bz2", ".fq.bz2"} fastq_extensions_escape = ( @@ -239,7 +225,7 @@ def fastq_formatter(fastq_entry: FastqEntry) -> str: def subsample_fastq_files( input_fastq_files: list[str], target_num_reads: int, target_read_length: int ) -> list[FastqEntry]: - with FileInput(input_fastq_files, openhook=hook_compressed_text) as fastq_fi: + with FileInput(input_fastq_files, openhook=hook_compressed, encoding="utf-8") as fastq_fi: # Fill reservoir selected_reads = list( itertools.islice(fastq_reader(fastq_fi, target_read_length), target_num_reads)