Commit 6e83042: Merge branch 'develop'
ksahlin committed Sep 6, 2023
2 parents 843b12d + 8e3ad06
Showing 2 changed files with 169 additions and 18 deletions.
isONcorrect (185 changes: 168 additions & 17 deletions)
@@ -16,36 +16,68 @@ from collections import deque
from collections import defaultdict

import edlib
import parasail

from modules import create_augmented_reference, help_functions, correct_seqs #,align

def eprint(*args, **kwargs):
print(*args, file=sys.stderr, **kwargs)


def rindex(lst, value):
return len(lst) - operator.indexOf(reversed(lst), value) - 1

def get_kmer_minimizers(seq, k_size, w_size):
# kmers = [seq[i:i+k_size] for i in range(len(seq)-k_size) ]
w = w_size - k_size
window_kmers = deque([hash(seq[i:i+k_size]) for i in range(w +1)])
curr_min = min(window_kmers)
minimizer_pos = rindex(list(window_kmers), curr_min)
minimizers = [ (seq[minimizer_pos: minimizer_pos+k_size], minimizer_pos) ] # get the last element if ties in window

for i in range(w+1,len(seq) - k_size):
new_kmer = hash(seq[i:i+k_size])
# updating window
discarded_kmer = window_kmers.popleft()
window_kmers.append(new_kmer)

# we have discarded previous window's minimizer, look for new minimizer brute force
if curr_min == discarded_kmer and minimizer_pos < i - w:
curr_min = min(window_kmers)
minimizer_pos = rindex(list(window_kmers), curr_min) + i - w
minimizers.append( (seq[minimizer_pos: minimizer_pos+k_size], minimizer_pos) ) # get the last element if ties in window

# Previous minimizer still in window, we only need to compare with the recently added kmer
elif new_kmer < curr_min:
curr_min = new_kmer
minimizers.append( (seq[i: i+k_size], i) )

return minimizers
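
Note that get_kmer_minimizers above ranks k-mers by Python's built-in hash(), which is salted per interpreter process for strings, so the chosen minimizers can differ between runs unless PYTHONHASHSEED is pinned. A minimal sketch of a deterministic substitute (illustrative only, not part of this commit):

import zlib

def stable_hash(kmer):
    # Illustrative: CRC32 is stable across processes, unlike built-in
    # hash() on str, making minimizer selection reproducible.
    return zlib.crc32(kmer.encode())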

def get_kmer_minimizers_lex(seq, k_size, w_size):
# kmers = [seq[i:i+k_size] for i in range(len(seq)-k_size) ]
w = w_size - k_size
window_kmers = deque([seq[i:i+k_size] for i in range(w +1)])
curr_min = min(window_kmers)
minimizers = [ (curr_min, list(window_kmers).index(curr_min)) ]
minimizer_pos = rindex(list(window_kmers), curr_min)
minimizers = [ (seq[minimizer_pos: minimizer_pos+k_size], minimizer_pos) ] # get the last element if ties in window

for i in range(w+1,len(seq) - k_size):
new_kmer = seq[i:i+k_size]
# updating window
discarded_kmer = window_kmers.popleft()
window_kmers.append(new_kmer)

# we have discarded previous windows minimizer, look for new minimizer brute force
if curr_min == discarded_kmer:
# we have discarded previous window's minimizer, look for new minimizer brute force
if curr_min == discarded_kmer and minimizer_pos < i - w:
curr_min = min(window_kmers)
minimizers.append( (curr_min, list(window_kmers).index(curr_min) + i - w ) )
minimizer_pos = rindex(list(window_kmers), curr_min) + i - w
minimizers.append( (seq[minimizer_pos: minimizer_pos+k_size], minimizer_pos) ) # get the last element if ties in window

# Previous minimizer still in window, we only need to compare with the recently added kmer
elif new_kmer < curr_min:
curr_min = new_kmer
minimizers.append( (curr_min, i) )
minimizers.append( (seq[i: i+k_size], i) )

return minimizers
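
A worked toy example of the lexicographic variant (illustrative, not part of this commit), with k_size = 3 and w_size = 5, so each window spans w_size - k_size + 1 = 3 consecutive 3-mers:

# Window 0 holds ACG, CGT, GTA -> rightmost smallest k-mer is ('ACG', 0);
# sliding on, CGT wins at position 1, and ACG recurs at position 4.
assert get_kmer_minimizers_lex("ACGTACGTAC", 3, 5) == [('ACG', 0), ('CGT', 1), ('ACG', 4)]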

@@ -83,8 +115,10 @@ def get_minimizers_and_positions_compressed(reads, w, k, hash_fcn):

seq_hpol_comp = ''.join(ch for ch, _ in itertools.groupby(seq))

if hash_fcn == "lex":
if hash_fcn == "random":
minimizers = get_kmer_minimizers(seq_hpol_comp, k, w)
elif hash_fcn == "lex":
minimizers = get_kmer_minimizers_lex(seq_hpol_comp, k, w)
elif hash_fcn == "rev_lex":
minimizers = get_kmer_maximizers(seq_hpol_comp, k, w)

Expand All @@ -101,8 +135,10 @@ def get_minimizers_and_positions(reads, w, k, hash_fcn):
M = {}
for r_id in reads:
(acc, seq, qual) = reads[r_id]
if hash_fcn == "lex":
if hash_fcn == "random":
minimizers = get_kmer_minimizers(seq, k, w)
elif hash_fcn == "lex":
minimizers = get_kmer_minimizers_lex(seq, k, w)
elif hash_fcn == "rev_lex":
minimizers = get_kmer_maximizers(seq, k, w)

@@ -141,6 +177,7 @@ def get_minimizer_combinations_database(reads, M, k, x_low, x_high):

avg_bundance = 0
singleton_minimzer = 0
high_abundance = 0
cnt = 1
abundants=[]
for m1 in list(M2.keys()):
@@ -154,13 +191,15 @@ def get_minimizer_combinations_database(reads, M, k, x_low, x_high):

if len(M2[m1][m2])// 3 > len(reads):
abundants.append((m1,m2, len(M2[m1][m2])//3 ))
if m2 == forbidden: # poly A tail
if m2 == forbidden or len(M2[m1][m2])// 3 > 10*len(reads): # poly A tail or highly abundant
del M2[m1][m2]
high_abundance += 1
for m1,m2,ab in sorted(abundants, key=lambda x: x[2], reverse=True):
print("Too abundant:", m1, m2, ab, len(reads))
print("Not unique within reads:", m1, m2, ab, len(reads))

print("Average abundance for non-unique minimizer-combs:", avg_bundance/float(cnt))
print("Number of singleton minimizer combinations filtered out:", singleton_minimzer)
print("Number of highly abundant minimizer combinations (10x more frequent than nr reads) or poly-A anchors filtered out :", high_abundance)

return M2
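
The filter added in this hunk can be restated compactly (illustrative sketch, assuming each stored occurrence in M2[m1][m2] occupies three list slots, as the // 3 above implies):

def is_dropped(m2, occ_list_len, n_reads, forbidden):
    # Drop poly-A anchors and minimizer pairs indexed more than
    # 10x as often as there are reads (likely repeat-induced).
    return m2 == forbidden or occ_list_len // 3 > 10 * n_reads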

@@ -1037,14 +1076,13 @@ def find_most_supported_span(r_id, m1, p1, m1_curr_spans, minimizer_combinations
curr_ref_start, curr_ref_end, curr_read_start, curr_read_end, curr_ed = already_computed[relevant_read_id]
if (curr_read_start <= pos1 and pos2 <= curr_read_end) and (curr_ref_start <= p1 and p2 <= curr_ref_end):
p_error_read = (quality_values_database[relevant_read_id][pos2 + k_size] - quality_values_database[relevant_read_id][pos1])/(pos2 + k_size - pos1)
p_error_sum_thresh = p_error_ref + p_error_read # curr_p_error_sum_thresh*len(ref_seq)
p_error_sum_thresh = (p_error_ref + p_error_read)*len(ref_seq) #max(8, (p_error_ref + p_error_read)*(1/3)*len(ref_seq)) # roughly a 1/3 of the errors are indels # curr_p_error_sum_thresh*len(ref_seq)
read_beg_diff = pos1 - curr_read_start
read_end_diff = pos2 - curr_read_end
ref_beg_diff = p1 - curr_ref_start
ref_end_diff = p2 - curr_ref_end

ed_est = curr_ed + math.fabs(ref_end_diff - read_end_diff) + math.fabs(read_beg_diff - ref_beg_diff)
if 0 <= ed_est <= p_error_sum_thresh*len(ref_seq): # < curr_p_error_sum_thresh*len(ref_seq):
if 0 <= ed_est <= p_error_sum_thresh: # max(8, p_error_sum_thresh*len(ref_seq)):
# seqs[relevant_read_id] = (pos1, pos2)
# add_items(seqs, relevant_read_id, pos1, pos2)
if relevant_read_id in to_add and ed_est >= to_add[relevant_read_id][3]:
@@ -1063,8 +1101,12 @@ def find_most_supported_span(r_id, m1, p1, m1_curr_spans, minimizer_combinations


p_error_read = (quality_values_database[relevant_read_id][pos2 + k_size] - quality_values_database[relevant_read_id][pos1])/(pos2 + k_size - pos1)
p_error_sum_thresh = p_error_ref + p_error_read #sum([D[char_] for char_ in read_qual])/len(read_qual) #+ 0.1
editdist = edlib_alignment(ref_seq, read_seq, p_error_sum_thresh*len(ref_seq))
# p_error_sum_thresh = (p_error_ref + p_error_read)*len(ref_seq) #max(8, (p_error_ref + p_error_read)*(1/3)*len(ref_seq)) # max(8,(p_error_ref + p_error_read)*(1/3)*len(ref_seq)) # roughly a 1/3 of the errors are indels #sum([D[char_] for char_ in read_qual])/len(read_qual) #+ 0.1
editdist = edlib_alignment(ref_seq, read_seq, len(ref_seq))
# p_error_sum_thresh = p_error_ref + p_error_read #sum([D[char_] for char_ in read_qual])/len(read_qual) #+ 0.1
# if p_error_sum_thresh*len(ref_seq) < 5:
# print(p_error_sum_thresh*len(ref_seq), p_error_sum_thresh,len(ref_seq))
# editdist = edlib_alignment(ref_seq, read_seq, p_error_sum_thresh*len(ref_seq))

tmp_cnt += 1
if editdist >= 0: # passing second edit distance check
@@ -1119,6 +1161,106 @@ def get_intervals_to_correct(opt_indicies, all_intervals_sorted_by_finish):

return intervals_to_correct



def cigar_to_seq(cigar, query, ref):
cigar_tuples = []
result = re.split(r'[=DXSMI]+', cigar)
i = 0
for length in result[:-1]:
i += len(length)
type_ = cigar[i]
i += 1
cigar_tuples.append((int(length), type_ ))

r_index = 0
q_index = 0
q_aln = []
r_aln = []
for length_ , type_ in cigar_tuples:
if type_ == "=" or type_ == "X":
q_aln.append(query[q_index : q_index + length_])
r_aln.append(ref[r_index : r_index + length_])

r_index += length_
q_index += length_

elif type_ == "I":
# insertion w.r.t. reference
r_aln.append('-' * length_)
q_aln.append(query[q_index: q_index + length_])
# only query index change
q_index += length_

elif type_ == 'D':
# deletion w.r.t. reference
r_aln.append(ref[r_index: r_index + length_])
q_aln.append('-' * length_)
# only ref index change
r_index += length_

else:
print("error")
print(cigar)
sys.exit()

return "".join([s for s in q_aln]), "".join([s for s in r_aln]), cigar_tuples


def parasail_alignment(s1, s2, match_score = 2, mismatch_penalty = -2, opening_penalty = 24, gap_ext = 1):
user_matrix = parasail.matrix_create("ACGT", match_score, mismatch_penalty)
result = parasail.sg_trace_scan_16(s1, s2, opening_penalty, gap_ext, user_matrix)
if result.saturated:
result = parasail.sg_trace_scan_32(s1, s2, opening_penalty, gap_ext, user_matrix)

# difference in how to obtain string from parasail between python v2 and v3...
if sys.version_info[0] < 3:
cigar_string = str(result.cigar.decode).decode('utf-8')
else:
cigar_string = str(result.cigar.decode, 'utf-8')
s1_alignment, s2_alignment, cigar_tuples = cigar_to_seq(cigar_string, s1, s2)
# print()
# print(s1_alignment)
# print(s2_alignment)
# print(cigar_string)
return s1_alignment, s2_alignment, cigar_string, cigar_tuples, result.score
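
A usage sketch for the helper above (illustrative; assumes the parasail package is installed). The two returned alignment strings are gap-padded to equal length, so they can be zipped column by column, as fix_correction below does:

s1_aln, s2_aln, cigar, cig_tuples, score = parasail_alignment("ACGTTACGT", "ACGTACGT")
assert len(s1_aln) == len(s2_aln)   # columns line up one-to-one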


def fix_correction(orig, corr):
seq = []
o_segm = []
c_segm = []
l = 0
for o, c in zip(orig,corr):
if o != '-' and c != '-':
if l > 10: # take original read segment
seq.append( ''.join([x for x in o_segm if x != '-']) )
elif l > 0: # take corrected read segment
seq.append( ''.join([x for x in c_segm if x != '-']) )
seq.append(c)
l=0
o_segm = []
c_segm = []
elif o == '-':
c_segm.append(c)
l += 1
elif c == '-':
o_segm.append(o)
l += 1
else:
raise("Parsing alignment error of parasail's alignment")

# if ending in an indel
if l > 10: # take original read segment
seq.append( ''.join([x for x in o_segm if x != '-']) )
elif l > 0: # take corrected read segment
seq.append( ''.join([x for x in c_segm if x != '-']) )
l=0
o_segm = []
c_segm = []
return ''.join([s for s in seq])
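
Two toy cases for fix_correction (illustrative, not part of this commit): indel runs longer than 10 alignment columns are treated as likely structural overcorrections and reverted to the original read, while shorter indels keep the correction:

orig = "A" + "ACGTACGTACGT" + "A"   # original read, aligned
corr = "A" + "-" * 12 + "A"         # correction deleted 12 bp
assert fix_correction(orig, corr) == "AACGTACGTACGTA"   # deletion reverted

orig, corr = "AACGTA", "A---TA"     # a 3 bp deletion is kept
assert fix_correction(orig, corr) == "ATA"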


def correct_read(seq, reads, intervals_to_correct, k_size, work_dir, v_depth_ratio_threshold, max_seqs_to_spoa, disable_numpy, verbose, use_racon):
corr_seq = []
# print(opt_indicies)
@@ -1161,7 +1303,15 @@ def correct_read(seq, reads, intervals_to_correct, k_size, work_dir, v_depth_rat
tmp.append( seq[ stop_ : corr_seq[cnt+1][0]] )

corr = "".join([s for s in tmp])
return corr, other_reads_corrected_regions

# check for structural overcorrections
start = time()
s1_alignment, s2_alignment, cigar_string, cigar_tuples, score = parasail_alignment(seq, corr, match_score=4, mismatch_penalty=-8, opening_penalty=12, gap_ext=1)
# print('Alignment took: ', time() - start )
adjusted_corr = fix_correction(s1_alignment, s2_alignment)
s1_alignment, s2_alignment, cigar_string, cigar_tuples, score = parasail_alignment(seq, adjusted_corr, match_score=4, mismatch_penalty=-8, opening_penalty=12, gap_ext=1)

return adjusted_corr, other_reads_corrected_regions


D = {chr(i) : min( 10**( - (ord(chr(i)) - 33)/10.0 ), 0.79433) for i in range(128)}
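
For reference (illustrative), D maps a FASTQ Phred quality character to its error probability, capped at 0.79433:

# D['!'] == 0.79433   # Q0, capped at the maximum error probability
# D['+'] == 0.1       # Q10, 10% error
# D['5'] == 0.01      # Q20, 1% error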
@@ -1281,6 +1431,8 @@ def main(args):
tmp_cnt = 0

for r_id in sorted(reads): #, reverse=True):
# print()
# print(reads[r_id][0])
if args.randstrobes:
seq = reads[r_id][1]
# print("seq length:", len(seq))
@@ -1400,7 +1552,6 @@ def main(args):
corrected_seq = seq
else:
intervals_to_correct = get_intervals_to_correct(opt_indicies[::-1], all_intervals)
# print(r_id, intervals_to_correct)
del all_intervals
all_intervals = []
corrected_seq, other_reads_corrected_regions = correct_read(seq, reads, intervals_to_correct, k_size, work_dir, v_depth_ratio_threshold, max_seqs_to_spoa, args.disable_numpy, args.verbose, args.use_racon)
@@ -1464,7 +1615,7 @@ def main(args):

if __name__ == '__main__':
parser = argparse.ArgumentParser(description="De novo error correction of long-read transcriptome reads", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--version', action='version', version='%(prog)s 0.0.8')
parser.add_argument('--version', action='version', version='%(prog)s 0.1.0')

parser.add_argument('--fastq', type=str, default=False, help='Path to input fastq file with reads')
# parser.add_argument('--t', dest="nr_cores", type=int, default=8, help='Number of cores allocated for clustering')
setup.py (2 changes: 1 addition & 1 deletion)
@@ -19,7 +19,7 @@
setup(

name='isONcorrect', # Required
version='0.0.8', # Required
version='0.1.0', # Required
description='De novo error-correction of long-read transcriptome reads.', # Required
long_description=long_description, # Optional
long_description_content_type='text/markdown',
