diff --git a/gene_splicer/utils.py b/gene_splicer/utils.py index 87e6220..a658b91 100644 --- a/gene_splicer/utils.py +++ b/gene_splicer/utils.py @@ -12,6 +12,8 @@ import glob from pathlib import Path from csv import DictWriter, DictReader +from itertools import groupby +from operator import itemgetter logger = logging.getLogger(__name__) @@ -412,39 +414,26 @@ def align(target_seq, ] def iterate_hivintact_data(name, outpath): - intact = {} + intact = set() def get_verdict(SEQID, all_errors): ordered = sorted(all_errors, key=HIVINTACT_ERRORS_TABLE.index) verdict = ordered[0] return [SEQID, verdict] - for d in glob.glob(str(outpath / 'hivintact*')): + for d in outpath.glob('hivintact*'): for (SEQID, sequence) in read_fasta(os.path.join(d, 'intact.fasta')): - row = [SEQID, 'Intact'] - intact[SEQID] = True - yield row + yield [SEQID, 'Intact'] + intact.add(SEQID) - sequence_name = None with open(os.path.join(d, 'errors.csv'), 'r') as f: reader = csv.DictReader(f) + grouped = groupby(reader, key=itemgetter('sequence_name')) + for sequence_name, errors in grouped: + if sequence_name not in intact: + all_errors = [error['error'] for error in errors] + yield get_verdict(sequence_name, all_errors) - last_name = None - all_errors = [] - for row in reader: - sequence_name = row['sequence_name'] - if sequence_name in intact: continue - - if last_name != sequence_name and last_name is not None: - if all_errors: - yield get_verdict(last_name, all_errors) - all_errors = [] - - all_errors.append(row['error']) - last_name = sequence_name - - if all_errors: - yield get_verdict(sequence_name, all_errors) def get_hivintact_data(name, outpath): column_names = ['SEQID', 'MyVerdict'] @@ -572,16 +561,14 @@ def generate_proviral_landscape_csv(outpath, is_hivintact): landscape_rows = [] table_precursor_csv = os.path.join(outpath, 'table_precursor.csv') - blastn_csv = glob.glob( - os.path.join(outpath, 'hivintact*', 'blast.csv') \ - if is_hivintact else \ - os.path.join( - outpath, - 'hivseqinr*', - 'Results_Intermediate', - 'Output_Blastn_HXB2MEGA28_tabdelim.txt' - ) - )[0] + + if is_hivintact: + subpath = os.path.join(outpath, 'hivintact*', 'blast.csv') + else: + subpath = os.path.join(outpath, 'hivseqinr*', 'Results_Intermediate', 'Output_Blastn_HXB2MEGA28_tabdelim.txt') + + blastn_csvs = glob.glob(subpath) + blastn_csv = blastn_csvs[0] with open(blastn_csv, 'r') as blastn_file: if is_hivintact: