From ff6bf18d9baacd97908b86e6dff4cecff92bb5a3 Mon Sep 17 00:00:00 2001 From: Dillon Barker Date: Wed, 4 Nov 2020 15:25:37 -0600 Subject: [PATCH] add input validation --- fsac/__init__.py | 2 +- fsac/main.py | 67 +++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/fsac/__init__.py b/fsac/__init__.py index 3a1909a..9051a19 100644 --- a/fsac/__init__.py +++ b/fsac/__init__.py @@ -1,3 +1,3 @@ -__version__ = '1.2.0' +__version__ = '1.2.1' __author__ = 'Dillon Barker' __email__ = 'dillon.barker@canada.ca' diff --git a/fsac/main.py b/fsac/main.py index dd444ba..0d48fa1 100644 --- a/fsac/main.py +++ b/fsac/main.py @@ -1,4 +1,6 @@ import argparse +import itertools +import json import logging import os import sys @@ -6,7 +8,7 @@ from . import __version__ from .allele_call import allele_call -from .update import update_directory +from .update import update_directory, get_known_alleles from .tabulate import tabulate_calls # Ensure numpy, via pandas, doesn't use more than 1 thread. 
# fsac/main.py -- code following main(): input-validation helpers and the
# entry points that call them before doing any work.


def validate_fasta(fasta_path: Path):
    """Check that *fasta_path* is a file that ``get_known_alleles`` can read.

    Returns a ``(status, message)`` tuple: ``(0, "")`` when the file is
    accepted, otherwise ``(1, reason)``.
    """
    if not fasta_path.is_file():
        return (1, f"{fasta_path} does not exist")

    try:
        get_known_alleles(fasta_path)

    # NOTE(review): UnboundLocalError is presumably what get_known_alleles
    # raises when it never sees a FASTA header -- confirm against fsac.update.
    except (UnboundLocalError, UnicodeDecodeError):
        return (1, f"{fasta_path} is not in FASTA format")

    return (0, "")


def validate_json(json_path: Path):
    """Check that *json_path* is a file containing parseable JSON.

    Returns a ``(status, message)`` tuple: ``(0, "")`` when the file parses,
    otherwise ``(1, reason)``.
    """
    if not json_path.is_file():
        return (1, f"{json_path} does not exist")

    try:
        with json_path.open("r") as f:
            json.load(f)

    # A binary file fails while decoding, before the JSON parser ever runs;
    # report it as invalid input instead of crashing with a traceback.
    except (json.JSONDecodeError, UnicodeDecodeError):
        return (1, f"{json_path} is not a valid JSON file")

    return (0, "")


def validate_directory(dir_path: Path, validation_method):
    """Apply *validation_method* to every entry directly inside *dir_path*.

    Returns a list of ``(status, message)`` tuples, one per entry.  If
    *dir_path* is not a directory, a single-element list describing that
    error is returned.  An empty directory yields an empty list.
    """
    if not dir_path.is_dir():
        return [(1, f"{dir_path} is not a directory")]

    # Sorted so that reported errors appear in a stable order between runs.
    return [validation_method(p) for p in sorted(dir_path.glob("*"))]


def validate(*args):
    """Report all validation failures and exit if there are any.

    Each positional argument is an iterable of ``(status, message)`` tuples
    as produced by the ``validate_*`` helpers.  When every status is zero
    (or there are no results at all) the function returns ``None``.
    Otherwise the non-empty messages are printed and the process exits with
    the number of errors as its status, capped at 255.
    """
    results = list(itertools.chain.from_iterable(args))

    n_errors = sum(status for status, _ in results)

    if n_errors > 0:

        messages = [message for _, message in results if message]

        print(f"Got {n_errors} input errors:")
        print('\n'.join(messages))
        print("Exiting.")

        # Exit statuses are truncated to 8 bits on POSIX, so an uncapped
        # count of 256 would be reported to the shell as success.
        sys.exit(min(n_errors, 255))


def call_alleles(args):
    """Validate ``args.input`` and ``args.alleles``, then run ``allele_call``."""

    validate([validate_fasta(args.input)],
             validate_directory(args.alleles, validate_fasta))

    allele_call(args.input, args.alleles, args.output)


def update_results(args):
    """Validate the JSON, allele and genome directories, then run
    ``update_directory``."""

    validate(validate_directory(args.json_dir, validate_json),
             validate_directory(args.alleles, validate_fasta),
             validate_directory(args.genome_dir, validate_fasta))

    update_directory(args.json_dir, args.alleles,
                     args.threshold, args.genome_dir)


def tabulate_allele_calls(args):
    """Validate ``args.json_dir``, then run ``tabulate_calls``."""

    validate(validate_directory(args.json_dir, validate_json))

    tabulate_calls(args.json_dir, args.output, args.delimiter)