diff --git a/isONcorrect b/isONcorrect
index 6758a38..b995899 100755
--- a/isONcorrect
+++ b/isONcorrect
@@ -16,6 +16,7 @@
 from collections import deque
 from collections import defaultdict
 import edlib
+import parasail
 
 from modules import create_augmented_reference, help_functions, correct_seqs #,align
 
@@ -23,13 +24,43 @@ def eprint(*args, **kwargs):
     print(*args, file=sys.stderr, **kwargs)
 
+def rindex(lst, value):
+    return len(lst) - operator.indexOf(reversed(lst), value) - 1
 
 def get_kmer_minimizers(seq, k_size, w_size):
+    # kmers = [seq[i:i+k_size] for i in range(len(seq)-k_size) ]
+    w = w_size - k_size
+    window_kmers = deque([hash(seq[i:i+k_size]) for i in range(w +1)])
+    curr_min = min(window_kmers)
+    minimizer_pos = rindex(list(window_kmers), curr_min)
+    minimizers = [ (seq[minimizer_pos: minimizer_pos+k_size], minimizer_pos) ] # get the last element if ties in window
+
+    for i in range(w+1,len(seq) - k_size):
+        new_kmer = hash(seq[i:i+k_size])
+        # updating window
+        discarded_kmer = window_kmers.popleft()
+        window_kmers.append(new_kmer)
+
+        # we have discarded previous window's minimizer, look for new minimizer brute force
+        if curr_min == discarded_kmer and minimizer_pos < i - w:
+            curr_min = min(window_kmers)
+            minimizer_pos = rindex(list(window_kmers), curr_min) + i - w
+            minimizers.append( (seq[minimizer_pos: minimizer_pos+k_size], minimizer_pos) ) # get the last element if ties in window
+
+        # Previous minimizer still in window, we only need to compare with the recently added kmer
+        elif new_kmer < curr_min:
+            curr_min = new_kmer
+            minimizers.append( (seq[i: i+k_size], i) )
+
+    return minimizers
+
+def get_kmer_minimizers_lex(seq, k_size, w_size):
     # kmers = [seq[i:i+k_size] for i in range(len(seq)-k_size) ]
     w = w_size - k_size
     window_kmers = deque([seq[i:i+k_size] for i in range(w +1)])
     curr_min = min(window_kmers)
-    minimizers = [ (curr_min, list(window_kmers).index(curr_min)) ]
+    minimizer_pos = rindex(list(window_kmers), curr_min)
+    minimizers = [ (seq[minimizer_pos: minimizer_pos+k_size], minimizer_pos) ] # get the last element if ties in window
 
     for i in range(w+1,len(seq) - k_size):
         new_kmer = seq[i:i+k_size]
@@ -37,15 +68,16 @@ def get_kmer_minimizers(seq, k_size, w_size):
         discarded_kmer = window_kmers.popleft()
         window_kmers.append(new_kmer)
 
-        # we have discarded previous window's minimizer, look for new minimizer brute force
-        if curr_min == discarded_kmer:
+        # we have discarded previous window's minimizer, look for new minimizer brute force
+        if curr_min == discarded_kmer and minimizer_pos < i - w:
             curr_min = min(window_kmers)
-            minimizers.append( (curr_min, list(window_kmers).index(curr_min) + i - w ) )
+            minimizer_pos = rindex(list(window_kmers), curr_min) + i - w
+            minimizers.append( (seq[minimizer_pos: minimizer_pos+k_size], minimizer_pos) ) # get the last element if ties in window
 
         # Previous minimizer still in window, we only need to compare with the recently added kmer
         elif new_kmer < curr_min:
             curr_min = new_kmer
-            minimizers.append( (curr_min, i) )
+            minimizers.append( (seq[i: i+k_size], i) )
 
     return minimizers
 
@@ -83,8 +115,10 @@ def get_minimizers_and_positions_compressed(reads, w, k, hash_fcn):
 
         seq_hpol_comp = ''.join(ch for ch, _ in itertools.groupby(seq))
 
-        if hash_fcn == "lex":
+        if hash_fcn == "random":
             minimizers = get_kmer_minimizers(seq_hpol_comp, k, w)
+        elif hash_fcn == "lex":
+            minimizers = get_kmer_minimizers_lex(seq_hpol_comp, k, w)
         elif hash_fcn == "rev_lex":
             minimizers = get_kmer_maximizers(seq_hpol_comp, k, w)
 
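Note on the hunks above: get_kmer_minimizers now orders k-mers by Python's built-in hash() rather than lexicographically (the old behavior moves to get_kmer_minimizers_lex), and rindex() picks the rightmost position on ties so repeated k-mers inside a window are sampled consistently. rindex relies on operator.indexOf, which assumes `import operator` already exists at the top of the file (not visible in these hunks). A standalone sketch of the same windowing logic, for reference only (simplified, not part of the patch; it additionally records the new minimizer's position in the elif branch so the stale-position guard stays exact):

    from collections import deque
    import operator

    def rindex(lst, value):
        # position of the LAST occurrence of value in lst
        return len(lst) - operator.indexOf(reversed(lst), value) - 1

    def minimizers_sketch(seq, k, w_size):
        w = w_size - k
        window = deque(hash(seq[i:i+k]) for i in range(w + 1))
        curr_min = min(window)
        pos = rindex(list(window), curr_min)
        sampled = [(seq[pos:pos+k], pos)]
        for i in range(w + 1, len(seq) - k):
            discarded = window.popleft()
            new = hash(seq[i:i+k])
            window.append(new)
            if curr_min == discarded and pos < i - w:
                # previous minimizer left the window: rescan brute force
                curr_min = min(window)
                pos = rindex(list(window), curr_min) + i - w
                sampled.append((seq[pos:pos+k], pos))
            elif new < curr_min:
                # newly added k-mer is the smallest in the window
                curr_min = new
                pos = i
                sampled.append((seq[i:i+k], i))
        return sampled

    # hash() of str is salted per process (PYTHONHASHSEED), so the sampled
    # positions differ between runs unless the seed is pinned.
    print(minimizers_sketch("ACGTTAGGCAACGTT", 3, 8))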
@@ -101,8 +135,10 @@ def get_minimizers_and_positions(reads, w, k, hash_fcn):
     M = {}
     for r_id in reads:
         (acc, seq, qual) = reads[r_id]
-        if hash_fcn == "lex":
+        if hash_fcn == "random":
             minimizers = get_kmer_minimizers(seq, k, w)
+        elif hash_fcn == "lex":
+            minimizers = get_kmer_minimizers_lex(seq, k, w)
         elif hash_fcn == "rev_lex":
             minimizers = get_kmer_maximizers(seq, k, w)
 
@@ -141,6 +177,7 @@ def get_minimizer_combinations_database(reads, M, k, x_low, x_high):
 
     avg_bundance = 0
     singleton_minimzer = 0
+    high_abundance = 0
     cnt = 1
     abundants=[]
     for m1 in list(M2.keys()):
@@ -154,13 +191,15 @@ def get_minimizer_combinations_database(reads, M, k, x_low, x_high):
             if len(M2[m1][m2])// 3 > len(reads):
                 abundants.append((m1,m2, len(M2[m1][m2])//3 ))
-            if m2 == forbidden: # poly A tail
+            if m2 == forbidden or len(M2[m1][m2])// 3 > 10*len(reads): # poly A tail or highly abundant
                 del M2[m1][m2]
+                high_abundance += 1
 
     for m1,m2,ab in sorted(abundants, key=lambda x: x[2], reverse=True):
-        print("Too abundant:", m1, m2, ab, len(reads))
+        print("Not unique within reads:", m1, m2, ab, len(reads))
 
     print("Average abundance for non-unique minimizer-combs:", avg_bundance/float(cnt))
     print("Number of singleton minimizer combinations filtered out:", singleton_minimzer)
+    print("Number of highly abundant minimizer combinations (10x more frequent than the number of reads) or poly-A anchors filtered out:", high_abundance)
 
     return M2
 
@@ -1037,14 +1076,13 @@ def find_most_supported_span(r_id, m1, p1, m1_curr_spans, minimizer_combinations
             curr_ref_start, curr_ref_end, curr_read_start, curr_read_end, curr_ed = already_computed[relevant_read_id]
             if (curr_read_start <= pos1 and pos2 <= curr_read_end) and (curr_ref_start <= p1 and p2 <= curr_ref_end):
                 p_error_read = (quality_values_database[relevant_read_id][pos2 + k_size] - quality_values_database[relevant_read_id][pos1])/(pos2 + k_size - pos1)
-                p_error_sum_thresh = p_error_ref + p_error_read # curr_p_error_sum_thresh*len(ref_seq)
+                p_error_sum_thresh = (p_error_ref + p_error_read)*len(ref_seq) #max(8, (p_error_ref + p_error_read)*(1/3)*len(ref_seq)) # roughly a 1/3 of the errors are indels # curr_p_error_sum_thresh*len(ref_seq)
                 read_beg_diff = pos1 - curr_read_start
                 read_end_diff = pos2 - curr_read_end
                 ref_beg_diff = p1 - curr_ref_start
                 ref_end_diff = p2 - curr_ref_end
-
                 ed_est = curr_ed + math.fabs(ref_end_diff - read_end_diff) + math.fabs(read_beg_diff - ref_beg_diff)
-                if 0 <= ed_est <= p_error_sum_thresh*len(ref_seq): # < curr_p_error_sum_thresh*len(ref_seq):
+                if 0 <= ed_est <= p_error_sum_thresh: # max(8, p_error_sum_thresh*len(ref_seq)):
                     # seqs[relevant_read_id] = (pos1, pos2)
                     # add_items(seqs, relevant_read_id, pos1, pos2)
                     if relevant_read_id in to_add and ed_est >= to_add[relevant_read_id][3]:
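Note on the hunk above: the reuse shortcut estimates the edit distance of a new (reference, read) interval pair from an already-computed overlapping alignment, adding the coordinate drift at both ends; the threshold now pre-multiplies by len(ref_seq), so the accept test reads `ed_est <= p_error_sum_thresh` with unchanged semantics. A sketch of the bound with a worked example (function name illustrative, not part of the patch):

    import math

    def estimate_ed(curr_ed, p1, p2, pos1, pos2,
                    curr_ref_start, curr_ref_end,
                    curr_read_start, curr_read_end):
        # a previously computed alignment spans the new interval; the new edit
        # distance can grow by at most the begin/end drift between read and ref
        read_beg_diff = pos1 - curr_read_start
        read_end_diff = pos2 - curr_read_end
        ref_beg_diff = p1 - curr_ref_start
        ref_end_diff = p2 - curr_ref_end
        return curr_ed + math.fabs(ref_end_diff - read_end_diff) \
                       + math.fabs(read_beg_diff - ref_beg_diff)

    # e.g. the previous alignment had 3 errors; the new span trims 2 bp more
    # from the read start than from the reference start, ends are identical:
    # 3 + |0 - 0| + |2 - 1| = 4.0
    print(estimate_ed(3, 11, 90, 52, 131, 10, 90, 50, 131))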
@@ -1063,8 +1101,12 @@ def find_most_supported_span(r_id, m1, p1, m1_curr_spans, minimizer_combinations
 
             p_error_read = (quality_values_database[relevant_read_id][pos2 + k_size] - quality_values_database[relevant_read_id][pos1])/(pos2 + k_size - pos1)
 
-            p_error_sum_thresh = p_error_ref + p_error_read #sum([D[char_] for char_ in read_qual])/len(read_qual) #+ 0.1
-            editdist = edlib_alignment(ref_seq, read_seq, p_error_sum_thresh*len(ref_seq))
+            # p_error_sum_thresh = (p_error_ref + p_error_read)*len(ref_seq) #max(8, (p_error_ref + p_error_read)*(1/3)*len(ref_seq)) # roughly a 1/3 of the errors are indels #sum([D[char_] for char_ in read_qual])/len(read_qual) #+ 0.1
+            editdist = edlib_alignment(ref_seq, read_seq, len(ref_seq))
+            # p_error_sum_thresh = p_error_ref + p_error_read #sum([D[char_] for char_ in read_qual])/len(read_qual) #+ 0.1
+            # if p_error_sum_thresh*len(ref_seq) < 5:
+            #     print(p_error_sum_thresh*len(ref_seq), p_error_sum_thresh,len(ref_seq))
+            # editdist = edlib_alignment(ref_seq, read_seq, p_error_sum_thresh*len(ref_seq))
 
             tmp_cnt += 1
             if editdist >= 0: # passing second edit distance check
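Note on the hunk above: the second edlib check no longer uses the quality-derived threshold (kept only as commented-out alternatives) and instead caps the band at len(ref_seq), i.e. any alignment edlib can find within one full reference length of edits now passes; the `editdist >= 0` test works because the distance comes back as -1 when the cap is exceeded. A sketch of what the edlib_alignment() wrapper (defined elsewhere in this file) presumably does, assuming the standard edlib Python API:

    import edlib

    def edlib_alignment_sketch(ref_seq, read_seq, max_ed):
        # banded edit-distance computation; k caps the search so that
        # hopeless candidates fail fast
        result = edlib.align(read_seq, ref_seq, task="distance", k=max_ed)
        return result["editDistance"]   # -1 if the distance exceeds max_ed

    print(edlib_alignment_sketch("ACGTACGT", "ACGAACGT", 8))   # -> 1
    print(edlib_alignment_sketch("ACGTACGT", "ACGAACGT", 0))   # -> -1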
@@ -1119,6 +1161,106 @@ def get_intervals_to_correct(opt_indicies, all_intervals_sorted_by_finish):
     return intervals_to_correct
 
 
+
+
+def cigar_to_seq(cigar, query, ref):
+    cigar_tuples = []
+    result = re.split(r'[=DXSMI]+', cigar)
+    i = 0
+    for length in result[:-1]:
+        i += len(length)
+        type_ = cigar[i]
+        i += 1
+        cigar_tuples.append((int(length), type_ ))
+
+    r_index = 0
+    q_index = 0
+    q_aln = []
+    r_aln = []
+    for length_ , type_ in cigar_tuples:
+        if type_ == "=" or type_ == "X":
+            q_aln.append(query[q_index : q_index + length_])
+            r_aln.append(ref[r_index : r_index + length_])
+            r_index += length_
+            q_index += length_
+
+        elif type_ == "I":
+            # insertion w.r.t. reference
+            r_aln.append('-' * length_)
+            q_aln.append(query[q_index: q_index + length_])
+            # only query index change
+            q_index += length_
+
+        elif type_ == 'D':
+            # deletion w.r.t. reference
+            r_aln.append(ref[r_index: r_index + length_])
+            q_aln.append('-' * length_)
+            # only ref index change
+            r_index += length_
+
+        else:
+            print("error")
+            print(cigar)
+            sys.exit()
+
+    return "".join([s for s in q_aln]), "".join([s for s in r_aln]), cigar_tuples
+
+
+def parasail_alignment(s1, s2, match_score = 2, mismatch_penalty = -2, opening_penalty = 24, gap_ext = 1):
+    user_matrix = parasail.matrix_create("ACGT", match_score, mismatch_penalty)
+    result = parasail.sg_trace_scan_16(s1, s2, opening_penalty, gap_ext, user_matrix)
+    if result.saturated:
+        result = parasail.sg_trace_scan_32(s1, s2, opening_penalty, gap_ext, user_matrix)
+
+    # difference in how to obtain string from parasail between python v2 and v3...
+    if sys.version_info[0] < 3:
+        cigar_string = str(result.cigar.decode).decode('utf-8')
+    else:
+        cigar_string = str(result.cigar.decode, 'utf-8')
+    s1_alignment, s2_alignment, cigar_tuples = cigar_to_seq(cigar_string, s1, s2)
+    # print()
+    # print(s1_alignment)
+    # print(s2_alignment)
+    # print(cigar_string)
+    return s1_alignment, s2_alignment, cigar_string, cigar_tuples, result.score
+
+
+def fix_correction(orig, corr):
+    seq = []
+    o_segm = []
+    c_segm = []
+    l = 0
+    for o, c in zip(orig,corr):
+        if o != '-' and c != '-':
+            if l > 10: # take original read segment
+                seq.append( ''.join([x for x in o_segm if x != '-']) )
+            elif l > 0: # take corrected read segment
+                seq.append( ''.join([x for x in c_segm if x != '-']) )
+            seq.append(c)
+            l = 0
+            o_segm = []
+            c_segm = []
+        elif o == '-':
+            c_segm.append(c)
+            l += 1
+        elif c == '-':
+            o_segm.append(o)
+            l += 1
+        else:
+            raise Exception("Parsing alignment error of parasail's alignment")
+
+    # if ending in an indel
+    if l > 10: # take original read segment
+        seq.append( ''.join([x for x in o_segm if x != '-']) )
+    elif l > 0: # take corrected read segment
+        seq.append( ''.join([x for x in c_segm if x != '-']) )
+    l = 0
+    o_segm = []
+    c_segm = []
+    return ''.join([s for s in seq])
+
+
 def correct_read(seq, reads, intervals_to_correct, k_size, work_dir, v_depth_ratio_threshold, max_seqs_to_spoa, disable_numpy, verbose, use_racon):
     corr_seq = []
     # print(opt_indicies)
@@ -1161,7 +1303,15 @@ def correct_read(seq, reads, intervals_to_correct, k_size, work_dir, v_depth_rat
     tmp.append( seq[ stop_ : corr_seq[cnt+1][0]] )
 
     corr = "".join([s for s in tmp])
-    return corr, other_reads_corrected_regions
+
+    # check for structural overcorrections
+    start = time()
+    s1_alignment, s2_alignment, cigar_string, cigar_tuples, score = parasail_alignment(seq, corr, match_score=4, mismatch_penalty=-8, opening_penalty=12, gap_ext=1)
+    # print('Alignment took: ', time() - start )
+    adjusted_corr = fix_correction(s1_alignment, s2_alignment)
+    s1_alignment, s2_alignment, cigar_string, cigar_tuples, score = parasail_alignment(seq, adjusted_corr, match_score=4, mismatch_penalty=-8, opening_penalty=12, gap_ext=1)
+
+    return adjusted_corr, other_reads_corrected_regions
 
 
 D = {chr(i) : min( 10**( - (ord(chr(i)) - 33)/10.0 ), 0.79433) for i in range(128)}
@@ -1281,6 +1431,8 @@
 
     tmp_cnt = 0
     for r_id in sorted(reads): #, reverse=True):
+        # print()
+        # print(reads[r_id][0])
        if args.randstrobes:
            seq = reads[r_id][1]
            # print("seq length:", len(seq))
@@ -1400,7 +1552,6 @@
                 corrected_seq = seq
             else:
                 intervals_to_correct = get_intervals_to_correct(opt_indicies[::-1], all_intervals)
-                # print(r_id, intervals_to_correct)
                 del all_intervals
                 all_intervals = []
                 corrected_seq, other_reads_corrected_regions = correct_read(seq, reads, intervals_to_correct, k_size, work_dir, v_depth_ratio_threshold, max_seqs_to_spoa, args.disable_numpy, args.verbose, args.use_racon)
@@ -1464,7 +1615,7 @@
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="De novo error correction of long-read transcriptome reads", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--version', action='version', version='%(prog)s 0.0.8')
+    parser.add_argument('--version', action='version', version='%(prog)s 0.1.0')
     parser.add_argument('--fastq', type=str, default=False, help='Path to input fastq file with reads')
     # parser.add_argument('--t', dest="nr_cores", type=int, default=8, help='Number of cores allocated for clustering')
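Note on the isONcorrect hunks above: correct_read() now post-processes each corrected read by aligning it back to the original with parasail (semi-global, free end gaps) and running fix_correction() over the alignment columns; indel runs longer than 10 columns revert to the original read segment (guarding against structural overcorrections such as a dropped exon), while shorter runs keep the correction. A toy end-to-end example of that guard, assuming the two functions from this patch are in scope (sequences invented for illustration):

    # original read carries a 15 bp block that the consensus step removed
    seq  = "ACGT" + "ACGTACGTACGTACG" + "TTGG"
    corr = "ACGT" + "TTGG"

    s1_aln, s2_aln, cigar, tuples, score = parasail_alignment(
        seq, corr, match_score=4, mismatch_penalty=-8, opening_penalty=12, gap_ext=1)
    # expected rows:  s1_aln == "ACGTACGTACGTACGTACGTTGG"
    #                 s2_aln == "ACGT---------------TTGG"
    adjusted = fix_correction(s1_aln, s2_aln)
    # the 15-column gap run exceeds the 10-column cutoff, so the original
    # segment is restored
    assert adjusted == seq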
diff --git a/setup.py b/setup.py
index 285e1d8..ac13c5a 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,7 @@ setup(
     name='isONcorrect',  # Required
 
-    version='0.0.8',  # Required
+    version='0.1.0',  # Required
 
     description='De novo error-correction of long-read transcriptome reads.',  # Required
     long_description=long_description,  # Optional
     long_description_content_type='text/markdown',