From fe15297bc2d2d6e08ef4063d8355554dab0ededc Mon Sep 17 00:00:00 2001
From: Douglas Wu <wckdouglas@utexas.edu>
Date: Thu, 13 Dec 2018 14:47:14 -0600
Subject: [PATCH] added option for error correction algorithm

---
 bin/pe_fq_merge.py                        |  5 +++-
 sequencing_tools/fastq_tools/pe_align.pyx | 32 ++++++++++++++++++-----
 2 files changed, 29 insertions(+), 8 deletions(-)

diff --git a/bin/pe_fq_merge.py b/bin/pe_fq_merge.py
index 854024a..86d19d6 100644
--- a/bin/pe_fq_merge.py
+++ b/bin/pe_fq_merge.py
@@ -22,6 +22,9 @@ def getOptions():
             help='Maximum error rate of alignment (default: 0.1)')
     parser.add_argument('-a','--all',action='store_true',
         help='Output all bases (default: only overlapping regions)')
+    parser.add_argument('-c','--conserved', action='store_true', 
+                        help = 'Use of a voting algorithm, '\
+                        'otherwise use posterior error from quality')
     return parser.parse_args()
 
 def main():
@@ -29,7 +32,7 @@ def main():
     outfile=args.outfile
     outfile_handle = sys.stdout if outfile == '-' or outfile == '/dev/stdin' else xopen(outfile,mode = 'w')
     merge_interleaved(args.interleaved, outfile_handle,
-            args.min_len, args.error, args.all)
+            args.min_len, args.error, args.all, args.conserved)
 
 
 if __name__ == '__main__':
diff --git a/sequencing_tools/fastq_tools/pe_align.pyx b/sequencing_tools/fastq_tools/pe_align.pyx
index 0c6ffa7..d06b3df 100644
--- a/sequencing_tools/fastq_tools/pe_align.pyx
+++ b/sequencing_tools/fastq_tools/pe_align.pyx
@@ -7,6 +7,9 @@ import sys
 from builtins import zip
 from functools import partial
 from cpython cimport bool
+from sequencing_tools.bam_tools.read_cluster import calculate_concensus_base, prob_to_qual_string
+cdef:
+    double EPSILON = 0.999999
 
 
 cdef calibrate_qual(str b1, str b2, str q1, str q2):
@@ -38,7 +41,20 @@ cdef calibrate_qual(str b1, str b2, str q1, str q2):
         qual = 0
     return base, chr(qual + 33)
 
-        
+cdef posterior_error(str r1_seq, str r1_qual, str r2_seq, str r2_qual):  # merge two reads position by position; returns (seq, qual)
+    cdef:
+        str seq='', qual=''  # consensus bases / quality characters accumulated so far
+        str b1, b2, q1, q2  # base and quality character of read 1 and read 2 at the current position
+        double q  # NOTE(review): declared but never used in this function
+    # zip stops at the shortest of the four strings, so any extra trailing characters are ignored
+    iterator = zip(r1_seq, r1_qual, r2_seq, r2_qual)
+    for b1, q1, b2, q2 in iterator:
+        b, correct_prob = calculate_concensus_base(([b1, b2], [q1,q2], 0))  # single tuple argument; meaning of the trailing 0 is not visible here -- TODO confirm
+        seq += b
+        qual += prob_to_qual_string(correct_prob)  # presumably encodes the probability as a quality character -- verify against helper
+    return seq, qual
+
+
 
 cdef correct_error(str r1_seq, str r1_qual, str r2_seq, str r2_qual):
     '''
@@ -74,7 +90,7 @@ cdef correct_error(str r1_seq, str r1_qual, str r2_seq, str r2_qual):
 
 
 cdef make_concensus(float error_toleration, int min_len, 
-                bool report_all,
+                bool report_all, concensus_function, 
                 fastqRecord R1, fastqRecord R2):
     '''
     reverse complement read2 sequence and find matching position on read1
@@ -88,20 +104,21 @@ cdef make_concensus(float error_toleration, int min_len,
         str right_add_seq =''
         str left_add_qual = ''
         str right_add_qual = ''
+        bool no_indel
 
     
     r1_id = R1.id.split('/')[0]
-
     r2_seq = reverse_complement(R2.seq)
     r2_qual = R2.qual[::-1]
     aligned = locate(R1.seq, r2_seq, error_toleration)
     if aligned:
         r1_start, r1_end, r2_start, r2_end, match, err = aligned 
-        if match >= min_len:
+        no_indel =(r1_end - r1_start) == (r2_end - r2_start)
+        if match >= min_len and no_indel:
 #                print(aligned, file=sys.stdout)
 #                print(R1.seq[r1_start:r1_end], file=sys.stdout)
 #                print(r2_seq[r2_start:r2_end], file=sys.stdout)
-            seq, qual = correct_error(R1.seq[r1_start:r1_end], 
+            seq, qual = concensus_function(R1.seq[r1_start:r1_end], 
                                     R1.qual[r1_start:r1_end], 
                                     r2_seq[r2_start:r2_end],
                                     r2_qual[r2_start:r2_end])
@@ -120,7 +137,7 @@ cdef make_concensus(float error_toleration, int min_len,
     return out_line
 
 
-def merge_interleaved(infile, outfile_handle, min_len, error_toleration, report_all):
+def merge_interleaved(infile, outfile_handle, min_len, error_toleration, report_all, conserved=False):
     cdef:
         fastqRecord R1, R2
         int record_count = 0
@@ -128,7 +145,8 @@ def merge_interleaved(infile, outfile_handle, min_len, error_toleration, report_
 
     infile_handle = sys.stdin if infile == '-' or infile == '/dev/stdin' else xopen(infile,mode = 'r')
 
-    concensus_builder = partial(make_concensus, error_toleration, min_len, report_all)
+    concensus_function = partial(posterior_error) if not conserved else partial(correct_error)
+    concensus_builder = partial(make_concensus, error_toleration, min_len, report_all, concensus_function)
 
     for R1, R2 in read_interleaved(infile_handle):
         out_line = concensus_builder(R1, R2)