From b0f8e0fe7c4fe478459e87ea8bbe81f777de394d Mon Sep 17 00:00:00 2001 From: Vitaliy Mysak Date: Tue, 31 Dec 2024 15:05:36 -0800 Subject: [PATCH] Factor out referencefull_contig_stitcher --- micall/core/contig_stitcher.py | 662 +----------------- micall/drivers/sample.py | 2 +- micall/tests/test_contig_stitcher.py | 20 +- micall/utils/referencefull_contig_stitcher.py | 658 +++++++++++++++++ micall/utils/referenceless_contig_stitcher.py | 9 +- 5 files changed, 681 insertions(+), 670 deletions(-) create mode 100644 micall/utils/referencefull_contig_stitcher.py diff --git a/micall/core/contig_stitcher.py b/micall/core/contig_stitcher.py index e4c452dae..1536d0569 100644 --- a/micall/core/contig_stitcher.py +++ b/micall/core/contig_stitcher.py @@ -1,669 +1,13 @@ -from typing import Iterable, Optional, Tuple, List, Dict, Literal, TypeVar, TextIO, Sequence -from collections import defaultdict -import csv -import os -from dataclasses import replace -from math import ceil -from functools import reduce -from itertools import tee, islice, chain -from gotoh import align_it -from queue import LifoQueue -from Bio import Seq +from typing import Sequence import logging -from fractions import Fraction -from operator import itemgetter -from aligntools import CigarHit, connect_nonoverlapping_cigar_hits, drop_overlapping_cigar_hits, CigarActions -from micall.core.project_config import ProjectConfig -from micall.core.plot_contigs import plot_stitcher_coverage -from micall.utils.contig_stitcher_context import context, StitcherContext -from micall.utils.contig_stitcher_contigs import Contig, GenotypedContig, AlignedContig -from micall.utils.consensus_aligner import align_consensus -import micall.utils.contig_stitcher_events as events +from micall.utils.referencefull_contig_stitcher \ + import referencefull_contig_stitcher -T = TypeVar("T") logger = logging.getLogger(__name__) -def log(e: events.EventType) -> None: - context.get().emit(e) - logger.debug("%s", e) - - -def cut_query(self: 
GenotypedContig, cut_point: float) -> Tuple[GenotypedContig, GenotypedContig]: - """ Cuts query sequence in two parts with cut_point between them. """ - - cut_point = max(0.0, cut_point) - left = replace(self, name=None, seq=self.seq[:ceil(cut_point)]) - right = replace(self, name=None, seq=self.seq[ceil(cut_point):]) - return left, right - - -def cut_reference(self: AlignedContig, cut_point: float) -> Tuple[AlignedContig, AlignedContig]: - """ Cuts this alignment in two parts with cut_point between them. """ - - alignment_left, alignment_right = self.alignment.cut_reference(cut_point) - left = replace(self, name=None, alignment=alignment_left) - right = replace(self, name=None, alignment=alignment_right) - log(events.Cut(self, left, right, cut_point)) - return left, right - - -def lstrip(self: AlignedContig) -> AlignedContig: - """ - Trims the query sequence of the contig from its beginning up to the start of the - alignment. The CIGAR alignment is also updated to reflect the trimming. - """ - - alignment = self.alignment.lstrip_reference().lstrip_query() - q_remainder, query = cut_query(self, alignment.q_st - 0.5) - alignment = alignment.translate(0, -1 * alignment.q_st) - result = AlignedContig.make(query, alignment, self.strand) - log(events.LStrip(self, result)) - return result - - -def rstrip(self: AlignedContig) -> AlignedContig: - """ - Trims the query sequence of the contig from its end based on the end of the - alignment. The CIGAR alignment is also updated to reflect the trimming. 
- """ - - alignment = self.alignment.rstrip_reference().rstrip_query() - query, q_remainder = cut_query(self, alignment.q_ei + 0.5) - result = AlignedContig.make(query, alignment, self.strand) - log(events.RStrip(self, result)) - return result - - -def overlap(a: AlignedContig, b: AlignedContig) -> bool: - def intervals_overlap(x, y): - return x[0] <= y[1] and x[1] >= y[0] - - if a.group_ref != b.group_ref: - return False - - return intervals_overlap((a.alignment.r_st, a.alignment.r_ei), - (b.alignment.r_st, b.alignment.r_ei)) - - -def munge(self: AlignedContig, other: AlignedContig) -> AlignedContig: - """ - Combines two adjacent contigs into a single contig by joining their - query sequences and alignments. - """ - - match_fraction = min(self.match_fraction, other.match_fraction) - ref_name = max([self, other], key=lambda x: x.alignment.ref_length).ref_name - query = GenotypedContig(seq=self.seq + other.seq, - name=None, - ref_name=ref_name, - group_ref=self.group_ref, - ref_seq=self.ref_seq, - match_fraction=match_fraction) - - self_alignment = self.alignment - other_alignment = \ - other.alignment.translate( - query_delta=(-1 * other.alignment.q_st + self.alignment.q_ei + 1), - reference_delta=0) - alignment = self_alignment.connect(other_alignment) - - ret = AlignedContig.make(query=query, alignment=alignment, strand=self.strand) - log(events.Munge(self, other, ret)) - return ret - - -def sliding_window(sequence: Iterable[T]) -> Iterable[Tuple[Optional[T], T, Optional[T]]]: - """ - Generate a three-element sliding window of a sequence. - - Each element generated contains a tuple with the previous item (None if the first item), - the current item, and the next item (None if the last item) in the sequence. 
- """ - - a, b, c = tee(sequence, 3) - prevs = chain([None], a) - nexts = chain(islice(c, 1, None), [None]) - return zip(prevs, b, nexts) - - -def combine_contigs(parts: List[AlignedContig]) -> AlignedContig: - """ - Combine a list of contigs into a single AlignedContig by trimming and merging overlapping parts. - - Left-trimming and right-trimming occur at any shared overlapping points - between adjacent parts. munge() is used to combine contiguous parts without overlap. - """ - - stripped_parts = [] - for prev_part, part, next_part in sliding_window(parts): - if prev_part is not None: - part = lstrip(part) - if next_part is not None: - part = rstrip(part) - stripped_parts.append(part) - - ret = reduce(munge, stripped_parts) - log(events.Combine(stripped_parts, ret)) - return ret - - -def align_to_reference(contig: GenotypedContig) -> Iterable[GenotypedContig]: - """ - Align a single Contig to its reference sequence, producing potentially multiple aligned contigs. - - If the reference sequence (ref_seq) is unavailable, the contig is returned unaltered. - Otherwise, alignments are performed and contigs corresponding to each alignment are yielded. 
- """ - - if contig.ref_seq is None: - log(events.NoRef(contig)) - yield contig - return - - alignments, _algo = align_consensus(contig.ref_seq, contig.seq) - hits = [x.to_cigar_hit() for x in alignments] - strands: List[Literal["forward", "reverse"]] = ["forward" if x.strand == 1 else "reverse" for x in alignments] - - for i, (hit, strand) in enumerate(zip(hits, strands)): - log(events.InitialHit(contig, i, hit, strand)) - - if not hits: - log(events.ZeroHits(contig)) - yield contig - return - - if len(set(strands)) > 1: - log(events.StrandConflict(contig)) - yield contig - return - - strand = strands[0] - if strand == "reverse": - rc = str(Seq.Seq(contig.seq).reverse_complement()) - original_contig = contig - new_contig = replace(contig, seq=rc) - contig = new_contig - hits = [replace(hit, q_st=len(rc)-hit.q_ei-1, q_ei=len(rc)-hit.q_st-1) for hit in hits] - - log(events.ReverseComplement(original_contig, new_contig)) - for i, (hit, strand) in enumerate(zip(hits, strands)): - log(events.InitialHit(contig, i, hit, strand)) - - def quality(x: CigarHit): - mlen = sum(1 for x in x.cigar.relax().iterate_operations() - if x == CigarActions.MATCH) - return (mlen, x.ref_length) - - filtered = list(drop_overlapping_cigar_hits(hits, quality)) - connected = list(connect_nonoverlapping_cigar_hits(filtered)) - log(events.HitNumber(contig, list(zip(hits, strands)), connected)) - - for i, single_hit in enumerate(connected): - query = replace(contig, name=None) - part = AlignedContig.make(query, single_hit, strand) - log(events.ConnectedHit(contig, part, i)) - yield part - - -def strip_conflicting_mappings(contigs: Iterable[GenotypedContig]) -> Iterable[GenotypedContig]: - contigs = list(contigs) - names = {contig.id: contig for contig in contigs} - - def get_indexes(id: int) -> Tuple[int, int]: - contig = names[id] - if isinstance(contig, AlignedContig): - return contig.alignment.q_st, contig.alignment.r_st - else: - return -1, -1 - - reference_sorted = list(sorted(names.keys(), 
key=lambda id: get_indexes(id)[1])) - query_sorted = list(sorted(names.keys(), key=lambda id: get_indexes(id)[0])) - - def is_out_of_order(id: int) -> bool: - return reference_sorted.index(id) != query_sorted.index(id) - - sorted_by_query = sorted(contigs, key=lambda contig: get_indexes(contig.id)) - for prev_contig, contig, next_contig in sliding_window(sorted_by_query): - if isinstance(contig, AlignedContig): - original = contig - start = prev_contig.alignment.q_ei + 1 if isinstance(prev_contig, AlignedContig) else 0 - end = next_contig.alignment.q_st - 1 if isinstance(next_contig, AlignedContig) else len(contig.seq) - 1 - - if prev_contig is not None or is_out_of_order(original.id): - contig = lstrip(contig) - log(events.InitialStrip(original, start, original.alignment.q_st - 1)) - if next_contig is not None or is_out_of_order(original.id): - contig = rstrip(contig) - log(events.InitialStrip(original, original.alignment.q_ei + 1, end)) - - yield contig - - -def align_all_to_reference(contigs: Iterable[GenotypedContig]) -> Iterable[GenotypedContig]: - """ - Align multiple contigs to their respective reference sequences. - - Applies align_to_reference to each contig in the given collection, - flattening the result into a single list. - """ - - groups = map(align_to_reference, contigs) - groups = map(strip_conflicting_mappings, groups) - for group in groups: - yield from group - - -def align_queries(seq1: str, seq2: str) -> Tuple[str, str]: - """ - Globally align two query sequences against each other - and return the resulting aligned sequences in MSA format. - """ - - gap_open_penalty = 15 - gap_extend_penalty = 3 - use_terminal_gap_penalty = 1 - aseq1, aseq2, score = \ - align_it( - seq1, seq2, - gap_open_penalty, - gap_extend_penalty, - use_terminal_gap_penalty) - - return aseq1, aseq2 - - -def find_all_overlapping_contigs(self: AlignedContig, aligned_contigs): - """ - Yield all contigs from a collection that overlap with a given contig. 
- Contigs are considered overlapping if they have overlapping intervals on the same reference genome. - """ - - for other in aligned_contigs: - if overlap(self, other): - yield other - - -def find_overlapping_contig(self: AlignedContig, aligned_contigs): - """ - Find the single contig in a collection that overlaps the most with a given contig. - It returns the contig with the maximum overlapped reference length with the given contig (self). - """ - - every = find_all_overlapping_contigs(self, aligned_contigs) - return max(every, key=lambda other: other.alignment.ref_length if other else 0, default=None) - - -def calculate_concordance(left: str, right: str) -> List[Fraction]: - """ - Calculate concordance for two given sequences using a sliding average. - - The function compares the two strings character by character, simultaneously from - both left to right and right to left, calculating a score that represents a moving - average of matches at each position. If characters match at a given position, - a score of 1 is added; otherwise, a score of 0 is added. The score is then - averaged with the previous scores using a weighted sliding average where the - current score has a weight of 1/3 and the accumulated score has a weight of 2/3. - This sliding average score is halved and then processed again, but in reverse direction. 
- - :param left: string representing first sequence - :param right: string representing second sequence - :return: list representing concordance ratio for each position - """ - - if len(left) != len(right): - raise ValueError("Can only calculate concordance for same sized sequences") - - result: List[Fraction] = [Fraction(0)] * len(left) - - def slide(start, end): - scores_sum = Fraction(0) - inputs = list(zip(left, right)) - increment = 1 if start <= end else -1 - - for i in range(start, end, increment): - (a, b) = inputs[i] - current = Fraction(1) if a == b else Fraction(0) - scores_sum = (scores_sum * 2 / 3 + current * 1 / 3) - result[i] += scores_sum / 2 - - # Slide forward, then in reverse, adding the scores at each position. - slide(0, len(left)) - slide(len(left) - 1, -1) - - return result - - -def disambiguate_concordance(concordance: List[Fraction]) -> Iterable[Tuple[Fraction, int]]: - for i, x in enumerate(concordance): - global_rank = i if i < len(concordance) / 2 else len(concordance) - i - 1 - yield x, global_rank - - -def concordance_to_cut_points(left_overlap, right_overlap, aligned_left, aligned_right, concordance): - """ Determine optimal cut points for stitching based on sequence concordance in the overlap region. 
""" - - concordance_d = list(disambiguate_concordance(concordance)) - sorted_concordance_indexes = [i for i, v in sorted(enumerate(concordance_d), - key=itemgetter(1), - reverse=True, - )] - - def remove_dashes(s: str): - return s.replace('-', '') - - for max_concordance_index in sorted_concordance_indexes: - aligned_left_q_index = len(remove_dashes(aligned_left[:max_concordance_index])) - aligned_right_q_index = right_overlap.alignment.query_length - \ - len(remove_dashes(aligned_right[max_concordance_index:])) + 1 - aligned_left_r_index = left_overlap.alignment.coordinate_mapping.query_to_ref.left_max(aligned_left_q_index) - if aligned_left_r_index is None: - aligned_left_r_index = left_overlap.alignment.r_st - 1 - aligned_right_r_index = right_overlap.alignment.coordinate_mapping.query_to_ref.right_min(aligned_right_q_index) - if aligned_right_r_index is None: - aligned_right_r_index = right_overlap.alignment.r_ei + 1 - if aligned_right_r_index > aligned_left_r_index: - return aligned_left_r_index + 0.5, aligned_right_r_index - 0.5, max_concordance_index - - return left_overlap.alignment.r_st - 1 + 0.5, right_overlap.alignment.r_ei + 1 - 0.5, 0 - - -def stitch_2_contigs(left, right): - """ - Stitch two contigs together into a single coherent contig. - - The function handles the overlap by cutting both contigs into segments, aligning the - overlapping segments, and then choosing the optimal stitching points based on sequence - concordance. Non-overlapping segments are retained as is. - """ - - # Cut in 4 parts. 
- left_remainder, left_overlap = cut_reference(left, right.alignment.r_st - 0.5) - right_overlap, right_remainder = cut_reference(right, left.alignment.r_ei + 0.5) - left_overlap = lstrip(rstrip(left_overlap)) - right_overlap = lstrip(rstrip(right_overlap)) - left_remainder = rstrip(left_remainder) - right_remainder = lstrip(right_remainder) - log(events.StitchCut(left, right, left_overlap, right_overlap, left_remainder, right_remainder)) - - # Align overlapping parts, then recombine based on concordance. - aligned_left, aligned_right = align_queries(left_overlap.seq, right_overlap.seq) - concordance = calculate_concordance(aligned_left, aligned_right) - aligned_left_cutpoint, aligned_right_cutpoint, max_concordance_index = \ - concordance_to_cut_points(left_overlap, right_overlap, aligned_left, aligned_right, concordance) - left_overlap_take, left_overlap_drop = cut_reference(left_overlap, aligned_left_cutpoint) - right_overlap_drop, right_overlap_take = cut_reference(right_overlap, aligned_right_cutpoint) - - # Log it. - average_concordance = Fraction(sum(concordance) / (len(concordance) or 1)) - cut_point_location_scaled = max_concordance_index / (((len(concordance) or 1) - 1) or 1) - log(events.Overlap(left, right, left_overlap, right_overlap, - left_remainder, right_remainder, left_overlap_take, - right_overlap_take, concordance, average_concordance, - max_concordance_index, cut_point_location_scaled)) - - return combine_contigs([left_remainder, left_overlap_take, right_overlap_take, right_remainder]) - - -def combine_overlaps(contigs: List[AlignedContig]) -> Iterable[AlignedContig]: - """ - Repeatedly combine all overlapping aligned contigs into an iterable collection of contiguous AlignedContigs. - It proceeds by iterating through sorted contigs and stitching any overlapping ones until none are left. - """ - - # Going left-to-right through aligned contigs. 
- contigs = list(sorted(contigs, key=lambda x: x.alignment.r_st)) - while contigs: - current = contigs.pop(0) - - # Find overlap. If there isn't one - we are done with the current contig. - overlapping_contig = find_overlapping_contig(current, contigs) - if not overlapping_contig: - log(events.NoOverlap(current)) - yield current - continue - - # Replace two contigs by their stitched version, then loop with it. - new_contig = stitch_2_contigs(current, overlapping_contig) - contigs.remove(overlapping_contig) - contigs.insert(0, new_contig) - log(events.Stitch(current, overlapping_contig, new_contig)) - - -def merge_intervals(intervals: List[Tuple[int, int]]) -> List[Tuple[int, int]]: - """ - Merge overlapping and adjacent intervals. - Note that intervals are inclusive. - - :param intervals: A list of intervals [start, end] where 'start' and 'end' are integers. - :return: A list of merged intervals. - """ - - if not intervals: - return [] - - # Sort intervals by their starting values - sorted_intervals = sorted(intervals, key=lambda x: x[0]) - - merged_intervals = [sorted_intervals[0]] - for current in sorted_intervals[1:]: - current_start, current_end = current - last_start, last_end = merged_intervals[-1] - if current_start <= last_end + 1: - # Extend the last interval if there is an overlap or if they are adjacent - merged_intervals[-1] = (min(last_start, current_start), max(last_end, current_end)) - else: - # Add the current interval if there is no overlap - merged_intervals.append(current) - - return merged_intervals - - -def find_covered_contig(contigs: List[AlignedContig]) -> Tuple[Optional[AlignedContig], List[AlignedContig]]: - """ - Find and return the first contig that is completely covered by other contigs. - - :param contigs: List of all aligned contigs to be considered. - :return: An AlignedContig if there is one completely covered by others, None otherwise. 
- """ - - def calculate_cumulative_coverage(others) -> List[Tuple[int, int]]: - intervals = [(contig.alignment.r_st, contig.alignment.r_ei) for contig in others] - merged_intervals = merge_intervals(intervals) - return merged_intervals - - for current in contigs: - current_interval = (current.alignment.r_st, current.alignment.r_ei) - - # Create a map of cumulative coverage for contigs - overlaping_contigs = [x for x in contigs if x.id != current.id and overlap(current, x)] - cumulative_coverage = calculate_cumulative_coverage(overlaping_contigs) - - # Check if the current contig is covered by the cumulative coverage intervals - if any((cover_interval[0] <= current_interval[0] and cover_interval[1] >= current_interval[1]) - for cover_interval in cumulative_coverage): - return current, overlaping_contigs - - return None, [] - - -def drop_completely_covered(contigs: List[AlignedContig]) -> List[AlignedContig]: - """ Filter out all contigs that are contained within other contigs. """ - - contigs = contigs[:] - while contigs: - covered, covering = find_covered_contig(contigs) - if covered: - contigs.remove(covered) - log(events.Drop(covered, covering)) - else: - break - - return contigs - - -def split_contigs_with_gaps(contigs: List[AlignedContig]) -> List[AlignedContig]: - """ - Split contigs at large gaps if those gaps are covered by other contigs in the list. - - A gap within a contig is considered large based on a pre-defined threshold. If another contig aligns - within that gap's range, the contig is split into two around the midpoint of the gap. - """ - - def covered_by(gap, other): - # Check if any 1 reference coordinate in gap is mapped in `other`. 
- gap_coords = gap.coordinate_mapping.ref_to_query.domain - cover_coords = set(other.alignment.coordinate_mapping.ref_to_query.keys()) - return not gap_coords.isdisjoint(cover_coords) - - def covered(self, gap): - return any(covered_by(gap, other) for other in contigs if other != self) - - def significant(gap): - # noinspection PyLongLine - # The size of the gap is unavoidably, to some point, arbitrary. Here we tried to adjust it to common gaps in HIV, as HIV is the primary test subject in MiCall. A notable feature of HIV-1 reverse transcription is the appearance of periodic deletions of approximately 21 nucleotides. These deletions have been reported to occur in the HIV-1 genome and are thought to be influenced by the structure of the viral RNA. Specifically, the secondary structures and foldings of the RNA can lead to pause sites for the reverse transcriptase, resulting in staggered alignment when the enzyme slips. This misalignment can cause the reverse transcriptase to "jump," leading to deletions in the newly synthesized DNA. The unusually high frequency of about 21-nucleotide deletions is believed to correspond to the pitch of the RNA helix, which reflects the spatial arrangement of the RNA strands. The 21 nucleotide cycle is an average measure and is thought to be associated with the length of one turn of the RNA helix, meaning that when reverse transcriptase slips and reattaches, it often does so one helical turn away from the original site. # noqa: E501 - return gap.ref_length > 21 - - def try_split(self: AlignedContig): - for gap in self.alignment.deletions(): - if not significant(gap): - # Really we do not want to split on every little deletion - # because that would mean that we would need to stitch - # overlaps around them. - # And we are likely to lose quality with every stitching operation. - # By skipping we assert that this gap is aligner's fault. 
- log(events.IgnoreGap(self, gap)) - continue - - if covered(self, gap): - midpoint = gap.r_st + (gap.r_ei - gap.r_st) / 2 + self.alignment.epsilon - left_part, right_part = cut_reference(self, midpoint) - left_part = rstrip(left_part) - right_part = lstrip(right_part) - - contigs.remove(self) - contigs.append(left_part) - contigs.append(right_part) - process_queue.put(right_part) - log(events.SplitGap(self, gap, left_part, right_part)) - return - - process_queue: LifoQueue = LifoQueue() - for contig in contigs: - process_queue.put(contig) - - while not process_queue.empty(): - contig = process_queue.get() - try_split(contig) - - return contigs - - -def stitch_contigs(contigs: Iterable[GenotypedContig]) -> Iterable[GenotypedContig]: - contigs = list(contigs) - for contig in contigs: - log(events.Intro(contig)) - contig.register() - - maybe_aligned = list(align_all_to_reference(contigs)) - - # Contigs that did not align do not need any more processing - yield from (x for x in maybe_aligned if not isinstance(x, AlignedContig)) - aligned = [x for x in maybe_aligned if isinstance(x, AlignedContig)] - - aligned = split_contigs_with_gaps(aligned) - aligned = drop_completely_covered(aligned) - yield from combine_overlaps(aligned) - - -GroupRef = Optional[str] - - -def stitch_consensus(contigs: Iterable[GenotypedContig]) -> Iterable[GenotypedContig]: - contigs = list(stitch_contigs(contigs)) - consensus_parts: Dict[GroupRef, List[AlignedContig]] = defaultdict(list) - - for contig in contigs: - if isinstance(contig, AlignedContig): - consensus_parts[contig.group_ref].append(contig) - else: - yield contig - - def combine(group_ref): - ctgs = sorted(consensus_parts[group_ref], key=lambda x: x.alignment.r_st) - result = combine_contigs(ctgs) - log(events.FinalCombine(ctgs, result)) - return result - - yield from map(combine, consensus_parts) - - -def write_contigs(output_csv: TextIO, contigs: Iterable[GenotypedContig]): - writer = csv.DictWriter(output_csv, - ['ref', 'match', 
'group_ref', 'contig'], - lineterminator=os.linesep) - writer.writeheader() - for contig in contigs: - writer.writerow(dict(ref=contig.ref_name, - match=contig.match_fraction, - group_ref=contig.group_ref, - contig=contig.seq)) - - output_csv.flush() - - -def read_referenceless_contigs(input_csv: TextIO) -> Iterable[Contig]: - for row in csv.DictReader(input_csv): - seq = row['contig'] - yield Contig(name=None, seq=seq) - - -def read_referencefull_contigs(input_csv: TextIO) -> Iterable[GenotypedContig]: - projects = ProjectConfig.loadDefault() - - for row in csv.DictReader(input_csv): - seq = row['contig'] - ref_name = row['ref'] - group_ref = row['group_ref'] - match_fraction = float(row['match']) - - try: - ref_seq = projects.getGenotypeReference(group_ref) - except KeyError: - try: - ref_seq = projects.getReference(group_ref) - except KeyError: - ref_seq = None - - yield GenotypedContig(name=None, - seq=seq, - ref_name=ref_name, - group_ref=group_ref, - ref_seq=str(ref_seq) if ref_seq is not None else None, - match_fraction=match_fraction) - - -def referencefull_contig_stitcher(input_csv: TextIO, - output_csv: TextIO, - stitcher_plot_path: Optional[str], - ) -> int: - with StitcherContext.fresh() as ctx: - contigs = list(read_referencefull_contigs(input_csv)) - - if output_csv is not None or stitcher_plot_path is not None: - contigs = list(stitch_consensus(contigs)) - - if output_csv is not None: - write_contigs(output_csv, contigs) - - if stitcher_plot_path is not None: - plot_stitcher_coverage(ctx.events, stitcher_plot_path) - - return len(contigs) - - def main(argv: Sequence[str]): import argparse diff --git a/micall/drivers/sample.py b/micall/drivers/sample.py index 5a01b470f..e7be374f8 100644 --- a/micall/drivers/sample.py +++ b/micall/drivers/sample.py @@ -9,7 +9,6 @@ from micall.core.aln2counts import aln2counts from micall.core.amplicon_finder import write_merge_lengths_plot, merge_for_entropy from micall.core.cascade_report import CascadeReport -from 
micall.core.contig_stitcher import referencefull_contig_stitcher from micall.core.coverage_plots import coverage_plot, concordance_plot from micall.core.plot_contigs import plot_genome_coverage from micall.core.prelim_map import prelim_map @@ -21,6 +20,7 @@ from micall.g2p.fastq_g2p import fastq_g2p, DEFAULT_MIN_COUNT, MIN_VALID, MIN_VALID_PERCENT from micall.utils.driver_utils import makedirs from micall.utils.fasta_to_csv import fasta_to_csv +from micall.utils.referencefull_contig_stitcher import referencefull_contig_stitcher from contextlib import contextmanager logger = logging.getLogger(__name__) diff --git a/micall/tests/test_contig_stitcher.py b/micall/tests/test_contig_stitcher.py index 7b94b251e..df56305ed 100644 --- a/micall/tests/test_contig_stitcher.py +++ b/micall/tests/test_contig_stitcher.py @@ -7,8 +7,8 @@ from aligntools import CigarActions, CigarHit, Cigar -import micall.core.contig_stitcher as stitcher -from micall.core.contig_stitcher import ( +import micall.utils.referencefull_contig_stitcher as stitcher +from micall.utils.referencefull_contig_stitcher import ( split_contigs_with_gaps, stitch_contigs, GenotypedContig, @@ -27,7 +27,7 @@ from micall.tests.test_remap import load_projects # activates the "projects" fixture -logging.getLogger("micall.core.contig_stitcher").setLevel(logging.DEBUG) +logging.getLogger("micall.utils.referencefull_contig_stitcher").setLevel(logging.DEBUG) logging.getLogger("micall.core.plot_contigs").setLevel(logging.DEBUG) @@ -39,7 +39,7 @@ @pytest.fixture() def exact_aligner(monkeypatch): - monkeypatch.setattr("micall.core.contig_stitcher.align_consensus", mock_align_consensus) + monkeypatch.setattr("micall.utils.referencefull_contig_stitcher.align_consensus", mock_align_consensus) @pytest.fixture @@ -1395,7 +1395,7 @@ def mock_align(reference_seq: str, consensus: str) -> Tuple[List[MockAlignment], algorithm = 'mock' return (alignments, algorithm) - monkeypatch.setattr("micall.core.contig_stitcher.align_consensus", 
mock_align) + monkeypatch.setattr("micall.utils.referencefull_contig_stitcher.align_consensus", mock_align) ref = 'A' * 700 seq = 'C' * 600 @@ -1457,10 +1457,11 @@ def test_correct_stitching_of_one_normal_and_one_unknown(exact_aligner, visualiz def test_main_invocation(exact_aligner, tmp_path, hcv_db): + from micall.core.contig_stitcher import main pwd = os.path.dirname(__file__) contigs = os.path.join(pwd, "data", "exact_parts_contigs.csv") stitched_contigs = os.path.join(tmp_path, "stitched.csv") - stitcher.main([contigs, stitched_contigs, "--use-references", "yes"]) + main([contigs, stitched_contigs, "--use-references", "yes"]) assert os.path.exists(contigs) assert os.path.exists(stitched_contigs) @@ -1479,13 +1480,14 @@ def test_main_invocation(exact_aligner, tmp_path, hcv_db): def test_visualizer_simple(exact_aligner, tmp_path, hcv_db): + from micall.core.contig_stitcher import main pwd = os.path.dirname(__file__) contigs = os.path.join(pwd, "data", "exact_parts_contigs.csv") stitched_contigs = os.path.join(tmp_path, "stitched.csv") plot = os.path.join(tmp_path, "exact_parts_contigs.plot.svg") - stitcher.main([contigs, stitched_contigs, - "--debug", "--plot", plot, - "--use-references", "yes"]) + main([contigs, stitched_contigs, + "--debug", "--plot", plot, + "--use-references", "yes"]) assert os.path.exists(contigs) assert os.path.exists(stitched_contigs) diff --git a/micall/utils/referencefull_contig_stitcher.py b/micall/utils/referencefull_contig_stitcher.py new file mode 100644 index 000000000..35027e53a --- /dev/null +++ b/micall/utils/referencefull_contig_stitcher.py @@ -0,0 +1,658 @@ +from typing import Iterable, Optional, Tuple, List, Dict, Literal, TypeVar, TextIO +from collections import defaultdict +import csv +import os +from dataclasses import replace +from math import ceil +from functools import reduce +from itertools import tee, islice, chain +from gotoh import align_it +from queue import LifoQueue +from Bio import Seq +import logging +from 
T = TypeVar("T")
logger = logging.getLogger(__name__)


def log(e: events.EventType) -> None:
    """ Emit the event to the current stitcher context and mirror it to the debug log. """

    context.get().emit(e)
    logger.debug("%s", e)


def cut_query(self: GenotypedContig, cut_point: float) -> Tuple[GenotypedContig, GenotypedContig]:
    """
    Cuts query sequence in two parts with cut_point between them.

    Negative cut points are clamped to 0. The sequence is split at index
    ceil(cut_point); both returned copies have their name reset to None.
    """

    cut_point = max(0.0, cut_point)
    left = replace(self, name=None, seq=self.seq[:ceil(cut_point)])
    right = replace(self, name=None, seq=self.seq[ceil(cut_point):])
    return left, right


def cut_reference(self: AlignedContig, cut_point: float) -> Tuple[AlignedContig, AlignedContig]:
    """
    Cuts this alignment in two parts with cut_point between them.

    Only the alignment is split; both returned copies have their name reset
    to None. A Cut event is logged.
    """

    alignment_left, alignment_right = self.alignment.cut_reference(cut_point)
    left = replace(self, name=None, alignment=alignment_left)
    right = replace(self, name=None, alignment=alignment_right)
    log(events.Cut(self, left, right, cut_point))
    return left, right


def lstrip(self: AlignedContig) -> AlignedContig:
    """
    Trims the query sequence of the contig from its beginning up to the start of the
    alignment. The CIGAR alignment is also updated to reflect the trimming.
    """

    alignment = self.alignment.lstrip_reference().lstrip_query()
    _, query = cut_query(self, alignment.q_st - 0.5)
    # The query now starts at the old q_st, so shift the alignment back to 0.
    alignment = alignment.translate(0, -1 * alignment.q_st)
    result = AlignedContig.make(query, alignment, self.strand)
    log(events.LStrip(self, result))
    return result


def rstrip(self: AlignedContig) -> AlignedContig:
    """
    Trims the query sequence of the contig from its end based on the end of the
    alignment. The CIGAR alignment is also updated to reflect the trimming.
    """

    alignment = self.alignment.rstrip_reference().rstrip_query()
    query, _ = cut_query(self, alignment.q_ei + 0.5)
    result = AlignedContig.make(query, alignment, self.strand)
    log(events.RStrip(self, result))
    return result


def overlap(a: AlignedContig, b: AlignedContig) -> bool:
    """
    Tell whether two contigs share a group_ref and their inclusive
    reference intervals [r_st, r_ei] intersect.
    """

    def intervals_overlap(x, y):
        return x[0] <= y[1] and x[1] >= y[0]

    if a.group_ref != b.group_ref:
        return False

    return intervals_overlap((a.alignment.r_st, a.alignment.r_ei),
                             (b.alignment.r_st, b.alignment.r_ei))


def munge(self: AlignedContig, other: AlignedContig) -> AlignedContig:
    """
    Combines two adjacent contigs into a single contig by joining their
    query sequences and alignments.

    The result takes the lower match_fraction of the two, and the ref_name
    of whichever contig has the longer aligned reference span.
    """

    match_fraction = min(self.match_fraction, other.match_fraction)
    ref_name = max([self, other], key=lambda x: x.alignment.ref_length).ref_name
    query = GenotypedContig(seq=self.seq + other.seq,
                            name=None,
                            ref_name=ref_name,
                            group_ref=self.group_ref,
                            ref_seq=self.ref_seq,
                            match_fraction=match_fraction)

    self_alignment = self.alignment
    # Move `other` in query space so that it starts right after `self` ends.
    other_alignment = \
        other.alignment.translate(
            query_delta=(-1 * other.alignment.q_st + self.alignment.q_ei + 1),
            reference_delta=0)
    alignment = self_alignment.connect(other_alignment)

    ret = AlignedContig.make(query=query, alignment=alignment, strand=self.strand)
    log(events.Munge(self, other, ret))
    return ret


def sliding_window(sequence: Iterable[T]) -> Iterable[Tuple[Optional[T], T, Optional[T]]]:
    """
    Generate a three-element sliding window of a sequence.

    Each element generated contains a tuple with the previous item (None if the first item),
    the current item, and the next item (None if the last item) in the sequence.
    """

    a, b, c = tee(sequence, 3)
    prevs = chain([None], a)
    nexts = chain(islice(c, 1, None), [None])
    return zip(prevs, b, nexts)


def combine_contigs(parts: List[AlignedContig]) -> AlignedContig:
    """
    Combine a list of contigs into a single AlignedContig by trimming and merging overlapping parts.

    Every part except the first is left-stripped and every part except the last
    is right-stripped; the stripped parts are then folded together with munge().
    """

    stripped_parts = []
    for prev_part, part, next_part in sliding_window(parts):
        if prev_part is not None:
            part = lstrip(part)
        if next_part is not None:
            part = rstrip(part)
        stripped_parts.append(part)

    ret = reduce(munge, stripped_parts)
    log(events.Combine(stripped_parts, ret))
    return ret


def align_to_reference(contig: GenotypedContig) -> Iterable[GenotypedContig]:
    """
    Align a single Contig to its reference sequence, producing potentially multiple aligned contigs.

    If the reference sequence (ref_seq) is unavailable, the contig is returned unaltered.
    The same happens when the aligner finds no hits, or hits on both strands.
    Otherwise, alignments are performed and contigs corresponding to each alignment are yielded.
    """

    if contig.ref_seq is None:
        log(events.NoRef(contig))
        yield contig
        return

    alignments, _algo = align_consensus(contig.ref_seq, contig.seq)
    hits = [x.to_cigar_hit() for x in alignments]
    strands: List[Literal["forward", "reverse"]] = ["forward" if x.strand == 1 else "reverse" for x in alignments]

    for i, (hit, strand) in enumerate(zip(hits, strands)):
        log(events.InitialHit(contig, i, hit, strand))

    if not hits:
        log(events.ZeroHits(contig))
        yield contig
        return

    if len(set(strands)) > 1:
        log(events.StrandConflict(contig))
        yield contig
        return

    strand = strands[0]
    if strand == "reverse":
        rc = str(Seq.Seq(contig.seq).reverse_complement())
        original_contig = contig
        new_contig = replace(contig, seq=rc)
        contig = new_contig
        # Mirror the query coordinates so they index into the reverse complement.
        hits = [replace(hit, q_st=len(rc)-hit.q_ei-1, q_ei=len(rc)-hit.q_st-1) for hit in hits]

        log(events.ReverseComplement(original_contig, new_contig))
        for i, (hit, strand) in enumerate(zip(hits, strands)):
            log(events.InitialHit(contig, i, hit, strand))

    def quality(x: CigarHit):
        # Rank hits by number of matched positions, then by reference span.
        mlen = sum(1 for op in x.cigar.relax().iterate_operations()
                   if op == CigarActions.MATCH)
        return (mlen, x.ref_length)

    filtered = list(drop_overlapping_cigar_hits(hits, quality))
    connected = list(connect_nonoverlapping_cigar_hits(filtered))
    log(events.HitNumber(contig, list(zip(hits, strands)), connected))

    for i, single_hit in enumerate(connected):
        query = replace(contig, name=None)
        part = AlignedContig.make(query, single_hit, strand)
        log(events.ConnectedHit(contig, part, i))
        yield part


def strip_conflicting_mappings(contigs: Iterable[GenotypedContig]) -> Iterable[GenotypedContig]:
    """
    Strip aligned contigs whose query regions are claimed by neighbouring parts.

    Contigs are visited in order of query start. An aligned contig is
    left-stripped if it has a predecessor, right-stripped if it has a
    successor, and stripped on both sides if its rank by reference start
    differs from its rank by query start. Contigs that are not aligned are
    yielded unchanged.
    """

    contigs = list(contigs)
    names = {contig.id: contig for contig in contigs}

    def get_indexes(contig_id: int) -> Tuple[int, int]:
        contig = names[contig_id]
        if isinstance(contig, AlignedContig):
            return contig.alignment.q_st, contig.alignment.r_st
        else:
            return -1, -1

    # Rank every contig once by reference start and once by query start,
    # so the out-of-order check below is a pair of dict lookups.
    reference_rank = {contig_id: rank for rank, contig_id in
                      enumerate(sorted(names, key=lambda contig_id: get_indexes(contig_id)[1]))}
    query_rank = {contig_id: rank for rank, contig_id in
                  enumerate(sorted(names, key=lambda contig_id: get_indexes(contig_id)[0]))}

    def is_out_of_order(contig_id: int) -> bool:
        return reference_rank[contig_id] != query_rank[contig_id]

    sorted_by_query = sorted(contigs, key=lambda contig: get_indexes(contig.id))
    for prev_contig, contig, next_contig in sliding_window(sorted_by_query):
        if isinstance(contig, AlignedContig):
            original = contig
            # `start` and `end` are only used to describe the stripped region in the log.
            start = prev_contig.alignment.q_ei + 1 if isinstance(prev_contig, AlignedContig) else 0
            end = next_contig.alignment.q_st - 1 if isinstance(next_contig, AlignedContig) else len(contig.seq) - 1

            if prev_contig is not None or is_out_of_order(original.id):
                contig = lstrip(contig)
                log(events.InitialStrip(original, start, original.alignment.q_st - 1))
            if next_contig is not None or is_out_of_order(original.id):
                contig = rstrip(contig)
                log(events.InitialStrip(original, original.alignment.q_ei + 1, end))

        yield contig


def align_all_to_reference(contigs: Iterable[GenotypedContig]) -> Iterable[GenotypedContig]:
    """
    Align multiple contigs to their respective reference sequences.

    Applies align_to_reference and then strip_conflicting_mappings to each
    contig in the given collection, flattening the result into a single stream.
    """

    groups = map(align_to_reference, contigs)
    groups = map(strip_conflicting_mappings, groups)
    for group in groups:
        yield from group


def align_queries(seq1: str, seq2: str) -> Tuple[str, str]:
    """
    Globally align two query sequences against each other
    and return the resulting aligned sequences in MSA format.
    """

    gap_open_penalty = 15
    gap_extend_penalty = 3
    use_terminal_gap_penalty = 1
    aseq1, aseq2, _ = \
        align_it(
            seq1, seq2,
            gap_open_penalty,
            gap_extend_penalty,
            use_terminal_gap_penalty)

    return aseq1, aseq2


def find_all_overlapping_contigs(self: AlignedContig, aligned_contigs):
    """
    Yield all contigs from a collection that overlap with a given contig.
    Contigs are considered overlapping if they have overlapping intervals on the same reference genome.
    """

    for other in aligned_contigs:
        if overlap(self, other):
            yield other


def find_overlapping_contig(self: AlignedContig, aligned_contigs):
    """
    Pick one contig from a collection that overlaps with a given contig.

    Among all overlapping candidates, the one with the largest aligned
    reference length (its whole alignment, not only the shared part) wins.
    Returns None if nothing overlaps.
    """

    every = find_all_overlapping_contigs(self, aligned_contigs)
    return max(every, key=lambda other: other.alignment.ref_length if other else 0, default=None)


def calculate_concordance(left: str, right: str) -> List[Fraction]:
    """
    Calculate concordance for two given sequences using a sliding average.

    The function compares the two strings character by character, simultaneously from
    both left to right and right to left, calculating a score that represents a moving
    average of matches at each position. If characters match at a given position,
    a score of 1 is added; otherwise, a score of 0 is added. The score is then
    averaged with the previous scores using a weighted sliding average where the
    current score has a weight of 1/3 and the accumulated score has a weight of 2/3.
    This sliding average score is halved and then processed again, but in reverse direction.

    :param left: string representing first sequence
    :param right: string representing second sequence
    :return: list representing concordance ratio for each position
    :raises ValueError: if the sequences differ in length
    """

    if len(left) != len(right):
        raise ValueError("Can only calculate concordance for same sized sequences")

    result: List[Fraction] = [Fraction(0)] * len(left)
    # Both passes read the same character pairs, so build them once.
    inputs = list(zip(left, right))

    def slide(start, end):
        scores_sum = Fraction(0)
        increment = 1 if start <= end else -1

        for i in range(start, end, increment):
            (a, b) = inputs[i]
            current = Fraction(1) if a == b else Fraction(0)
            scores_sum = (scores_sum * 2 / 3 + current * 1 / 3)
            result[i] += scores_sum / 2

    # Slide forward, then in reverse, adding the scores at each position.
    slide(0, len(left))
    slide(len(left) - 1, -1)

    return result


def disambiguate_concordance(concordance: List[Fraction]) -> Iterable[Tuple[Fraction, int]]:
    """
    Pair every concordance value with its distance from the nearest end of the list.

    The distance acts as a tie-breaker, so that among equal concordance
    values the positions closer to the middle sort higher.
    """

    for i, x in enumerate(concordance):
        global_rank = i if i < len(concordance) / 2 else len(concordance) - i - 1
        yield x, global_rank


def concordance_to_cut_points(left_overlap, right_overlap, aligned_left, aligned_right, concordance):
    """
    Determine optimal cut points for stitching based on sequence concordance in the overlap region.

    Positions are tried from the highest concordance down; the first one whose
    right reference coordinate lies beyond its left reference coordinate is used.

    :return: tuple of (left cut point, right cut point, chosen concordance index).
    """

    concordance_d = list(disambiguate_concordance(concordance))
    sorted_concordance_indexes = [i for i, v in sorted(enumerate(concordance_d),
                                                       key=itemgetter(1),
                                                       reverse=True,
                                                       )]

    def remove_dashes(s: str):
        return s.replace('-', '')

    for max_concordance_index in sorted_concordance_indexes:
        aligned_left_q_index = len(remove_dashes(aligned_left[:max_concordance_index]))
        aligned_right_q_index = right_overlap.alignment.query_length - \
            len(remove_dashes(aligned_right[max_concordance_index:])) + 1
        aligned_left_r_index = left_overlap.alignment.coordinate_mapping.query_to_ref.left_max(aligned_left_q_index)
        if aligned_left_r_index is None:
            aligned_left_r_index = left_overlap.alignment.r_st - 1
        aligned_right_r_index = right_overlap.alignment.coordinate_mapping.query_to_ref.right_min(aligned_right_q_index)
        if aligned_right_r_index is None:
            aligned_right_r_index = right_overlap.alignment.r_ei + 1
        if aligned_right_r_index > aligned_left_r_index:
            return aligned_left_r_index + 0.5, aligned_right_r_index - 0.5, max_concordance_index

    # No usable position: drop the whole left overlap and keep the whole right one.
    return left_overlap.alignment.r_st - 1 + 0.5, right_overlap.alignment.r_ei + 1 - 0.5, 0


def stitch_2_contigs(left, right):
    """
    Stitch two contigs together into a single coherent contig.

    The function handles the overlap by cutting both contigs into segments, aligning the
    overlapping segments, and then choosing the optimal stitching points based on sequence
    concordance. Non-overlapping segments are retained as is.
    """

    # Cut in 4 parts.
    left_remainder, left_overlap = cut_reference(left, right.alignment.r_st - 0.5)
    right_overlap, right_remainder = cut_reference(right, left.alignment.r_ei + 0.5)
    left_overlap = lstrip(rstrip(left_overlap))
    right_overlap = lstrip(rstrip(right_overlap))
    left_remainder = rstrip(left_remainder)
    right_remainder = lstrip(right_remainder)
    log(events.StitchCut(left, right, left_overlap, right_overlap, left_remainder, right_remainder))

    # Align overlapping parts, then recombine based on concordance.
    aligned_left, aligned_right = align_queries(left_overlap.seq, right_overlap.seq)
    concordance = calculate_concordance(aligned_left, aligned_right)
    aligned_left_cutpoint, aligned_right_cutpoint, max_concordance_index = \
        concordance_to_cut_points(left_overlap, right_overlap, aligned_left, aligned_right, concordance)
    left_overlap_take, left_overlap_drop = cut_reference(left_overlap, aligned_left_cutpoint)
    right_overlap_drop, right_overlap_take = cut_reference(right_overlap, aligned_right_cutpoint)

    # Log it.
    average_concordance = Fraction(sum(concordance) / (len(concordance) or 1))
    cut_point_location_scaled = max_concordance_index / (((len(concordance) or 1) - 1) or 1)
    log(events.Overlap(left, right, left_overlap, right_overlap,
                       left_remainder, right_remainder, left_overlap_take,
                       right_overlap_take, concordance, average_concordance,
                       max_concordance_index, cut_point_location_scaled))

    return combine_contigs([left_remainder, left_overlap_take, right_overlap_take, right_remainder])


def combine_overlaps(contigs: List[AlignedContig]) -> Iterable[AlignedContig]:
    """
    Repeatedly combine all overlapping aligned contigs into an iterable collection of contiguous AlignedContigs.
    It proceeds by iterating through sorted contigs and stitching any overlapping ones until none are left.
    """

    # Going left-to-right through aligned contigs.
    contigs = list(sorted(contigs, key=lambda x: x.alignment.r_st))
    while contigs:
        current = contigs.pop(0)

        # Find overlap. If there isn't one - we are done with the current contig.
        overlapping_contig = find_overlapping_contig(current, contigs)
        if not overlapping_contig:
            log(events.NoOverlap(current))
            yield current
            continue

        # Replace two contigs by their stitched version, then loop with it.
        new_contig = stitch_2_contigs(current, overlapping_contig)
        contigs.remove(overlapping_contig)
        contigs.insert(0, new_contig)
        log(events.Stitch(current, overlapping_contig, new_contig))


def merge_intervals(intervals: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    """
    Merge overlapping and adjacent intervals.
    Note that intervals are inclusive.

    :param intervals: A list of intervals [start, end] where 'start' and 'end' are integers.
    :return: A list of merged intervals, sorted by start.
    """

    if not intervals:
        return []

    # Sort intervals by their starting values
    sorted_intervals = sorted(intervals, key=lambda x: x[0])

    merged_intervals = [sorted_intervals[0]]
    for current in sorted_intervals[1:]:
        current_start, current_end = current
        last_start, last_end = merged_intervals[-1]
        if current_start <= last_end + 1:
            # Extend the last interval if there is an overlap or if they are adjacent
            merged_intervals[-1] = (min(last_start, current_start), max(last_end, current_end))
        else:
            # Add the current interval if there is no overlap
            merged_intervals.append(current)

    return merged_intervals


def find_covered_contig(contigs: List[AlignedContig]) -> Tuple[Optional[AlignedContig], List[AlignedContig]]:
    """
    Find and return the first contig that is completely covered by other contigs.

    :param contigs: List of all aligned contigs to be considered.
    :return: A pair of (covered contig, contigs overlapping it) for the first
             contig whose reference interval lies within a single merged
             interval of the others, or (None, []) if there is no such contig.
    """

    def calculate_cumulative_coverage(others) -> List[Tuple[int, int]]:
        intervals = [(contig.alignment.r_st, contig.alignment.r_ei) for contig in others]
        merged_intervals = merge_intervals(intervals)
        return merged_intervals

    for current in contigs:
        current_interval = (current.alignment.r_st, current.alignment.r_ei)

        # Create a map of cumulative coverage for contigs
        overlapping_contigs = [x for x in contigs if x.id != current.id and overlap(current, x)]
        cumulative_coverage = calculate_cumulative_coverage(overlapping_contigs)

        # Check if the current contig is covered by the cumulative coverage intervals
        if any((cover_interval[0] <= current_interval[0] and cover_interval[1] >= current_interval[1])
               for cover_interval in cumulative_coverage):
            return current, overlapping_contigs

    return None, []


def drop_completely_covered(contigs: List[AlignedContig]) -> List[AlignedContig]:
    """
    Filter out all contigs that are contained within other contigs.

    Works on a copy of the input list; each removal is logged as a Drop event.
    """

    contigs = contigs[:]
    while contigs:
        covered, covering = find_covered_contig(contigs)
        if covered:
            contigs.remove(covered)
            log(events.Drop(covered, covering))
        else:
            break

    return contigs


def split_contigs_with_gaps(contigs: List[AlignedContig]) -> List[AlignedContig]:
    """
    Split contigs at large gaps if those gaps are covered by other contigs in the list.

    A gap within a contig is considered large based on a pre-defined threshold. If another contig aligns
    within that gap's range, the contig is split into two around the midpoint of the gap.

    The input list is left untouched; a new list is returned.
    """

    # Work on a copy so that the caller's list is not modified.
    contigs = contigs[:]

    def covered_by(gap, other):
        # Check if any 1 reference coordinate in gap is mapped in `other`.
        gap_coords = gap.coordinate_mapping.ref_to_query.domain
        cover_coords = set(other.alignment.coordinate_mapping.ref_to_query.keys())
        return not gap_coords.isdisjoint(cover_coords)

    def covered(self, gap):
        return any(covered_by(gap, other) for other in contigs if other != self)

    def significant(gap):
        # The size of the gap is unavoidably, to some point, arbitrary.
        # Here we tried to adjust it to common gaps in HIV, as HIV is the
        # primary test subject in MiCall. A notable feature of HIV-1 reverse
        # transcription is the appearance of periodic deletions of
        # approximately 21 nucleotides. These deletions have been reported to
        # occur in the HIV-1 genome and are thought to be influenced by the
        # structure of the viral RNA. Specifically, the secondary structures
        # and foldings of the RNA can lead to pause sites for the reverse
        # transcriptase, resulting in staggered alignment when the enzyme
        # slips. This misalignment can cause the reverse transcriptase to
        # "jump," leading to deletions in the newly synthesized DNA. The
        # unusually high frequency of about 21-nucleotide deletions is
        # believed to correspond to the pitch of the RNA helix, which reflects
        # the spatial arrangement of the RNA strands. The 21 nucleotide cycle
        # is an average measure and is thought to be associated with the
        # length of one turn of the RNA helix, meaning that when reverse
        # transcriptase slips and reattaches, it often does so one helical
        # turn away from the original site.
        return gap.ref_length > 21

    def try_split(self: AlignedContig):
        for gap in self.alignment.deletions():
            if not significant(gap):
                # Really we do not want to split on every little deletion
                # because that would mean that we would need to stitch
                # overlaps around them.
                # And we are likely to lose quality with every stitching operation.
                # By skipping we assert that this gap is aligner's fault.
                log(events.IgnoreGap(self, gap))
                continue

            if covered(self, gap):
                midpoint = gap.r_st + (gap.r_ei - gap.r_st) / 2 + self.alignment.epsilon
                left_part, right_part = cut_reference(self, midpoint)
                left_part = rstrip(left_part)
                right_part = lstrip(right_part)

                contigs.remove(self)
                contigs.append(left_part)
                contigs.append(right_part)
                # Only the right part is queued again: the gaps left of the
                # split were already examined by this loop.
                process_queue.put(right_part)
                log(events.SplitGap(self, gap, left_part, right_part))
                return

    process_queue: LifoQueue = LifoQueue()
    for contig in contigs:
        process_queue.put(contig)

    while not process_queue.empty():
        contig = process_queue.get()
        try_split(contig)

    return contigs


def stitch_contigs(contigs: Iterable[GenotypedContig]) -> Iterable[GenotypedContig]:
    """
    Run the stitching pipeline: align, split at covered gaps, drop covered contigs, combine overlaps.

    Contigs that did not align are yielded first, unchanged; stitched aligned contigs follow.
    """

    contigs = list(contigs)
    for contig in contigs:
        log(events.Intro(contig))
        contig.register()

    maybe_aligned = list(align_all_to_reference(contigs))

    # Contigs that did not align do not need any more processing
    yield from (x for x in maybe_aligned if not isinstance(x, AlignedContig))
    aligned = [x for x in maybe_aligned if isinstance(x, AlignedContig)]

    aligned = split_contigs_with_gaps(aligned)
    aligned = drop_completely_covered(aligned)
    yield from combine_overlaps(aligned)


GroupRef = Optional[str]


def stitch_consensus(contigs: Iterable[GenotypedContig]) -> Iterable[GenotypedContig]:
    """
    Stitch contigs, then merge all aligned results of the same group_ref into one contig each.

    Contigs that are not aligned are yielded as they are; the combined contigs follow,
    one per group_ref.
    """

    contigs = list(stitch_contigs(contigs))
    consensus_parts: Dict[GroupRef, List[AlignedContig]] = defaultdict(list)

    for contig in contigs:
        if isinstance(contig, AlignedContig):
            consensus_parts[contig.group_ref].append(contig)
        else:
            yield contig

    def combine(group_ref):
        ctgs = sorted(consensus_parts[group_ref], key=lambda x: x.alignment.r_st)
        result = combine_contigs(ctgs)
        log(events.FinalCombine(ctgs, result))
        return result

    yield from map(combine, consensus_parts)


def write_contigs(output_csv: TextIO, contigs: Iterable[GenotypedContig]):
    """ Write contigs as CSV rows with columns ref, match, group_ref, contig, then flush the stream. """

    writer = csv.DictWriter(output_csv,
                            ['ref', 'match', 'group_ref', 'contig'],
                            lineterminator=os.linesep)
    writer.writeheader()
    for contig in contigs:
        writer.writerow(dict(ref=contig.ref_name,
                             match=contig.match_fraction,
                             group_ref=contig.group_ref,
                             contig=contig.seq))

    output_csv.flush()


def referencefull_contig_stitcher(input_csv: TextIO,
                                  output_csv: Optional[TextIO],
                                  stitcher_plot_path: Optional[str],
                                  ) -> int:
    """
    Read contigs from a CSV stream, stitch them, and write the requested outputs.

    Stitching only runs when at least one output is requested.

    :param input_csv: stream with columns contig, ref, group_ref, match.
    :param output_csv: stream to write the stitched contigs to, or None to skip writing.
    :param stitcher_plot_path: where to plot the stitcher coverage, or None to skip plotting.
    :return: number of contigs after processing.
    """

    with StitcherContext.fresh() as ctx:
        contigs = list(read_referencefull_contigs(input_csv))

        if output_csv is not None or stitcher_plot_path is not None:
            contigs = list(stitch_consensus(contigs))

        if output_csv is not None:
            write_contigs(output_csv, contigs)

        if stitcher_plot_path is not None:
            plot_stitcher_coverage(ctx.events, stitcher_plot_path)

        return len(contigs)


def read_referencefull_contigs(input_csv: TextIO) -> Iterable[GenotypedContig]:
    """
    Yield one unnamed GenotypedContig per row of the input CSV.

    The reference sequence is looked up by group_ref, first among genotype
    references, then among ordinary references; it is None if neither has it.
    """

    projects = ProjectConfig.loadDefault()

    for row in csv.DictReader(input_csv):
        seq = row['contig']
        ref_name = row['ref']
        group_ref = row['group_ref']
        match_fraction = float(row['match'])

        try:
            ref_seq = projects.getGenotypeReference(group_ref)
        except KeyError:
            try:
                ref_seq = projects.getReference(group_ref)
            except KeyError:
                ref_seq = None

        yield GenotypedContig(name=None,
                              seq=seq,
                              ref_name=ref_name,
                              group_ref=group_ref,
                              ref_seq=str(ref_seq) if ref_seq is not None else None,
                              match_fraction=match_fraction)
def read_referenceless_contigs(input_csv: TextIO) -> Iterable[Contig]:
    """ Lazily yield one unnamed Contig per CSV row, taking the sequence from the 'contig' column. """

    reader = csv.DictReader(input_csv)
    yield from (Contig(name=None, seq=record['contig']) for record in reader)