From 4ea1fd5c1a7681f11a276daf911de7531fc4ff31 Mon Sep 17 00:00:00 2001 From: Vitaliy Mysak Date: Fri, 10 Jan 2025 14:14:44 -0800 Subject: [PATCH] Improve scoring in referenceless --- micall/utils/referenceless_contig_stitcher.py | 123 ++++++++++++------ 1 file changed, 83 insertions(+), 40 deletions(-) diff --git a/micall/utils/referenceless_contig_stitcher.py b/micall/utils/referenceless_contig_stitcher.py index 0f47a28bf..29af797e9 100644 --- a/micall/utils/referenceless_contig_stitcher.py +++ b/micall/utils/referenceless_contig_stitcher.py @@ -40,21 +40,71 @@ def map_overlap(self, overlap: str) -> Iterator[Alignment]: yield x +@dataclass(frozen=True) +class Score: + pessimisstic_probability: Fraction + + # Lower is better. This is an estimated probability that + # all the components in this path came together by accident. + overall_probability: Fraction + path_lenth: int + + def combine(self, prob: Fraction) -> 'Score': + return Score(pessimisstic_probability=min(prob, self.pessimisstic_probability), + overall_probability=prob * self.pessimisstic_probability, + path_lenth=1 + self.path_lenth, + ) + + @staticmethod + def initial() -> 'Score': + return Score(pessimisstic_probability=ACCEPTABLE_STITCHING_PROB, + overall_probability=Fraction(1), + path_lenth=0) + + def __lt__(self, other: 'Score') -> bool: + if (1 - self.pessimisstic_probability) < (1 - other.pessimisstic_probability): + return True + if (1 - self.overall_probability) < (1 - other.overall_probability): + return True + if self.path_lenth < other.path_lenth: + return True + return False + + def __le__(self, other: 'Score') -> bool: + if (1 - self.pessimisstic_probability) <= (1 - other.pessimisstic_probability): + return True + if (1 - self.overall_probability) <= (1 - other.overall_probability): + return True + if self.path_lenth <= other.path_lenth: + return True + return False + + def __gt__(self, other: 'Score') -> bool: + if (1 - self.pessimisstic_probability) > (1 - other.pessimisstic_probability): + return True + if (1 - self.overall_probability) > (1 - other.overall_probability): + return True + if self.path_lenth > other.path_lenth: + return True + return False + + def __ge__(self, other: 'Score') -> bool: + if (1 - self.pessimisstic_probability) >= (1 - other.pessimisstic_probability): + return True + if (1 - self.overall_probability) >= (1 - other.overall_probability): + return True + if self.path_lenth >= other.path_lenth: + return True + return False + + @dataclass(frozen=True) class ContigsPath: # Contig representing all combined contigs in the path. whole: ContigWithAligner - # Id's of contigs that comprise this path. parts_ids: FrozenSet[int] - - # Lower is better. This is an estimated probability that - # all the components in this path came together by accident. - probability: Fraction - pessimisstic_probability: Fraction - - def score(self) -> Tuple[Fraction, Fraction, int]: - return (1-self.pessimisstic_probability, 1-self.probability, len(self.parts_ids)) + score: Score def has_contig(self, contig: Contig) -> bool: return contig.id in self.parts_ids @@ -65,9 +115,12 @@ def is_empty(self) -> bool: @staticmethod def empty() -> 'ContigsPath': - return ContigsPath(ContigWithAligner.empty(), frozenset(), - probability=Fraction(1), - pessimisstic_probability=ACCEPTABLE_STITCHING_PROB) + return ContigsPath(ContigWithAligner.empty(), frozenset(), Score.initial()) + + +@dataclass +class MinimumAcceptableScore: + value: Score @dataclass(frozen=True) @@ -106,23 +159,16 @@ def get_overlap(finder: OverlapFinder, left: ContigWithAligner, right: ContigWit return ret -def combine_probability(current: Fraction, new: Fraction) -> Fraction: - return current * new - - TRY_COMBINE_CACHE: MutableMapping[ Tuple[ContigId, ContigId], Optional[Tuple[ContigWithAligner, Fraction]]] = {} def try_combine_contigs(finder: OverlapFinder, - current_prob: Fraction, - max_acceptable_prob: Fraction, + current_score: Score, + min_acceptable_score: MinimumAcceptableScore, a: ContigWithAligner, b: ContigWithAligner, ) -> Optional[Tuple[ContigWithAligner, Fraction]]: - # TODO: Memoize this function. - # Two-layer caching seems most optimal: - # first by key=contig.id, then by key=contig.seq. if len(b.seq) == 0: return (a, Fraction(1)) @@ -145,7 +191,7 @@ def try_combine_contigs(finder: OverlapFinder, optimistic_number_of_matches = overlap.size optimistic_result_probability = calc_overlap_pvalue(L=overlap.size, M=optimistic_number_of_matches) - if combine_probability(optimistic_result_probability, current_prob) > max_acceptable_prob: + if current_score.combine(optimistic_result_probability) < min_acceptable_score.value: return None left_initial_overlap = left.seq[len(left.seq) - abs(shift):(len(left.seq) - abs(shift) + len(right.seq))] @@ -193,7 +239,7 @@ def try_combine_contigs(finder: OverlapFinder, in zip(aligned_left, aligned_right) if x == y and x != '-') result_probability = calc_overlap_pvalue(L=len(left_overlap), M=number_of_matches) - if combine_probability(current_prob, result_probability) > max_acceptable_prob: + if current_score.combine(result_probability) < min_acceptable_score.value: return None is_covered = len(right.seq) < abs(shift) @@ -221,42 +267,41 @@ def try_combine_contigs(finder: OverlapFinder, def extend_by_1(finder: OverlapFinder, - max_acceptable_prob: Fraction, + min_acceptable_score: MinimumAcceptableScore, path: ContigsPath, candidate: ContigWithAligner, ) -> Iterator[ContigsPath]: if path.has_contig(candidate): return - combination = try_combine_contigs(finder, path.probability, max_acceptable_prob, path.whole, candidate) + combination = try_combine_contigs(finder, path.score, min_acceptable_score, path.whole, candidate) if combination is None: return (combined, prob) = combination - probability = combine_probability(path.probability, prob) - pessimisstic_probability = min(path.pessimisstic_probability, prob) + score = path.score.combine(prob) new_elements = path.parts_ids.union([candidate.id]) - new_path = ContigsPath(combined, new_elements, probability, pessimisstic_probability) + new_path = ContigsPath(combined, new_elements, score) yield new_path def calc_extension(finder: OverlapFinder, - max_acceptable_prob: Fraction, + min_acceptable_score: MinimumAcceptableScore, contigs: Sequence[ContigWithAligner], path: ContigsPath, ) -> Iterator[ContigsPath]: for contig in contigs: - yield from extend_by_1(finder, max_acceptable_prob, path, contig) + yield from extend_by_1(finder, min_acceptable_score, path, contig) def calc_multiple_extensions(finder: OverlapFinder, - max_acceptable_prob: Fraction, + min_acceptable_score: MinimumAcceptableScore, paths: Iterable[ContigsPath], contigs: Sequence[ContigWithAligner], ) -> Iterator[ContigsPath]: for path in paths: - yield from calc_extension(finder, max_acceptable_prob, contigs, path) + yield from calc_extension(finder, min_acceptable_score, contigs, path) def filter_extensions(existing: MutableMapping[str, ContigsPath], @@ -267,7 +312,7 @@ def filter_extensions(existing: MutableMapping[str, ContigsPath], for path in extensions: key = path.whole.seq alternative = existing.get(key) - if alternative is None or path.score() > alternative.score(): + if alternative is None or path.score > alternative.score: existing[key] = path ret[key] = path @@ -275,10 +320,10 @@ def filter_extensions(existing: MutableMapping[str, ContigsPath], def calculate_all_paths(contigs: Sequence[ContigWithAligner]) -> Iterator[ContigsPath]: - max_acceptable_prob = ACCEPTABLE_STITCHING_PROB + min_acceptable_score = MinimumAcceptableScore(Score.initial()) existing: MutableMapping[str, ContigsPath] = {} finder = OverlapFinder.make('ACTG') - extensions = calc_extension(finder, max_acceptable_prob, contigs, ContigsPath.empty()) + extensions = calc_extension(finder, min_acceptable_score, contigs, ContigsPath.empty()) paths = tuple(filter_extensions(existing, extensions)) yield from paths @@ -287,7 +332,7 @@ def calculate_all_paths(contigs: Sequence[ContigWithAligner]) -> Iterator[Contig while paths: logger.debug("Cycle %s started with %s paths.", cycle, len(paths)) - extensions = calc_multiple_extensions(finder, max_acceptable_prob, paths, contigs) + extensions = calc_multiple_extensions(finder, min_acceptable_score, paths, contigs) paths = tuple(filter_extensions(existing, extensions)) if paths: @@ -305,17 +350,15 @@ def calculate_all_paths(contigs: Sequence[ContigWithAligner]) -> Iterator[Contig # the most promising alternative at each point. logger.debug("Dropping %s paths that have the lowest scores.", len(paths) - MAX_ALTERNATIVES) - paths = tuple(sorted(paths, key=ContigsPath.score)[-MAX_ALTERNATIVES:]) + paths = tuple(sorted(paths, key=lambda x: x.score)[-MAX_ALTERNATIVES:]) cycle += 1 yield from paths - if paths: - max_acceptable_prob = max(x.probability for x in paths) def find_most_probable_path(contigs: Sequence[ContigWithAligner]) -> ContigsPath: paths = calculate_all_paths(contigs) - return max(paths, key=ContigsPath.score) + return max(paths, key=lambda x: x.score) def stitch_consensus(contigs: Iterable[ContigWithAligner]) -> Iterator[ContigWithAligner]: