Skip to content

Commit

Permalink
Improve scoring in referenceless
Browse files Browse the repository at this point in the history
  • Loading branch information
Donaim committed Jan 10, 2025
1 parent 33288ca commit 4ea1fd5
Showing 1 changed file with 83 additions and 40 deletions.
123 changes: 83 additions & 40 deletions micall/utils/referenceless_contig_stitcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,71 @@ def map_overlap(self, overlap: str) -> Iterator[Alignment]:
yield x


@dataclass(frozen=True)
class Score:
pessimisstic_probability: Fraction

# Lower is better. This is an estimated probability that
# all the components in this path came together by accident.
overall_probability: Fraction
path_lenth: int

def combine(self, prob: Fraction) -> 'Score':
return Score(pessimisstic_probability=min(prob, self.pessimisstic_probability),
overall_probability=prob * self.pessimisstic_probability,
path_lenth=1 + self.path_lenth,
)

@staticmethod
def initial() -> 'Score':
return Score(pessimisstic_probability=ACCEPTABLE_STITCHING_PROB,
overall_probability=Fraction(1),
path_lenth=0)

def __lt__(self, other: 'Score') -> bool:
if (1 - self.pessimisstic_probability) < (1 - other.pessimisstic_probability):
return True
if (1 - self.overall_probability) < (1 - other.overall_probability):
return True
if self.path_lenth < other.path_lenth:
return True
return False

def __le__(self, other: 'Score') -> bool:
if (1 - self.pessimisstic_probability) <= (1 - other.pessimisstic_probability):
return True
if (1 - self.overall_probability) <= (1 - other.overall_probability):
return True
if self.path_lenth <= other.path_lenth:
return True
return False

def __gt__(self, other: 'Score') -> bool:
if (1 - self.pessimisstic_probability) > (1 - other.pessimisstic_probability):
return True
if (1 - self.overall_probability) > (1 - other.overall_probability):
return True
if self.path_lenth > other.path_lenth:
return True
return False

def __ge__(self, other: 'Score') -> bool:
if (1 - self.pessimisstic_probability) >= (1 - other.pessimisstic_probability):
return True
if (1 - self.overall_probability) >= (1 - other.overall_probability):
return True
if self.path_lenth >= other.path_lenth:
return True
return False


@dataclass(frozen=True)
class ContigsPath:
# Contig representing all combined contigs in the path.
whole: ContigWithAligner

# Id's of contigs that comprise this path.
parts_ids: FrozenSet[int]

# Lower is better. This is an estimated probability that
# all the components in this path came together by accident.
probability: Fraction
pessimisstic_probability: Fraction

def score(self) -> Tuple[Fraction, Fraction, int]:
return (1-self.pessimisstic_probability, 1-self.probability, len(self.parts_ids))
score: Score

def has_contig(self, contig: Contig) -> bool:
return contig.id in self.parts_ids
Expand All @@ -65,9 +115,12 @@ def is_empty(self) -> bool:

@staticmethod
def empty() -> 'ContigsPath':
return ContigsPath(ContigWithAligner.empty(), frozenset(),
probability=Fraction(1),
pessimisstic_probability=ACCEPTABLE_STITCHING_PROB)
return ContigsPath(ContigWithAligner.empty(), frozenset(), Score.initial())


@dataclass
class MinimumAcceptableScore:
value: Score


@dataclass(frozen=True)
Expand Down Expand Up @@ -106,23 +159,16 @@ def get_overlap(finder: OverlapFinder, left: ContigWithAligner, right: ContigWit
return ret


def combine_probability(current: Fraction, new: Fraction) -> Fraction:
return current * new


TRY_COMBINE_CACHE: MutableMapping[
Tuple[ContigId, ContigId],
Optional[Tuple[ContigWithAligner, Fraction]]] = {}


def try_combine_contigs(finder: OverlapFinder,
current_prob: Fraction,
max_acceptable_prob: Fraction,
current_score: Score,
min_acceptable_score: MinimumAcceptableScore,
a: ContigWithAligner, b: ContigWithAligner,
) -> Optional[Tuple[ContigWithAligner, Fraction]]:
# TODO: Memoize this function.
# Two-layer caching seems most optimal:
# first by key=contig.id, then by key=contig.seq.

if len(b.seq) == 0:
return (a, Fraction(1))
Expand All @@ -145,7 +191,7 @@ def try_combine_contigs(finder: OverlapFinder,

optimistic_number_of_matches = overlap.size
optimistic_result_probability = calc_overlap_pvalue(L=overlap.size, M=optimistic_number_of_matches)
if combine_probability(optimistic_result_probability, current_prob) > max_acceptable_prob:
if current_score.combine(optimistic_result_probability) < min_acceptable_score.value:
return None

left_initial_overlap = left.seq[len(left.seq) - abs(shift):(len(left.seq) - abs(shift) + len(right.seq))]
Expand Down Expand Up @@ -193,7 +239,7 @@ def try_combine_contigs(finder: OverlapFinder,
in zip(aligned_left, aligned_right)
if x == y and x != '-')
result_probability = calc_overlap_pvalue(L=len(left_overlap), M=number_of_matches)
if combine_probability(current_prob, result_probability) > max_acceptable_prob:
if current_score.combine(result_probability) < min_acceptable_score.value:
return None

is_covered = len(right.seq) < abs(shift)
Expand Down Expand Up @@ -221,42 +267,41 @@ def try_combine_contigs(finder: OverlapFinder,


def extend_by_1(finder: OverlapFinder,
max_acceptable_prob: Fraction,
min_acceptable_score: MinimumAcceptableScore,
path: ContigsPath,
candidate: ContigWithAligner,
) -> Iterator[ContigsPath]:
if path.has_contig(candidate):
return

combination = try_combine_contigs(finder, path.probability, max_acceptable_prob, path.whole, candidate)
combination = try_combine_contigs(finder, path.score, min_acceptable_score, path.whole, candidate)
if combination is None:
return

(combined, prob) = combination
probability = combine_probability(path.probability, prob)
pessimisstic_probability = min(path.pessimisstic_probability, prob)
score = path.score.combine(prob)
new_elements = path.parts_ids.union([candidate.id])
new_path = ContigsPath(combined, new_elements, probability, pessimisstic_probability)
new_path = ContigsPath(combined, new_elements, score)
yield new_path


def calc_extension(finder: OverlapFinder,
max_acceptable_prob: Fraction,
min_acceptable_score: MinimumAcceptableScore,
contigs: Sequence[ContigWithAligner],
path: ContigsPath,
) -> Iterator[ContigsPath]:

for contig in contigs:
yield from extend_by_1(finder, max_acceptable_prob, path, contig)
yield from extend_by_1(finder, min_acceptable_score, path, contig)


def calc_multiple_extensions(finder: OverlapFinder,
max_acceptable_prob: Fraction,
min_acceptable_score: MinimumAcceptableScore,
paths: Iterable[ContigsPath],
contigs: Sequence[ContigWithAligner],
) -> Iterator[ContigsPath]:
for path in paths:
yield from calc_extension(finder, max_acceptable_prob, contigs, path)
yield from calc_extension(finder, min_acceptable_score, contigs, path)


def filter_extensions(existing: MutableMapping[str, ContigsPath],
Expand All @@ -267,18 +312,18 @@ def filter_extensions(existing: MutableMapping[str, ContigsPath],
for path in extensions:
key = path.whole.seq
alternative = existing.get(key)
if alternative is None or path.score() > alternative.score():
if alternative is None or path.score > alternative.score:
existing[key] = path
ret[key] = path

yield from ret.values()


def calculate_all_paths(contigs: Sequence[ContigWithAligner]) -> Iterator[ContigsPath]:
max_acceptable_prob = ACCEPTABLE_STITCHING_PROB
min_acceptable_score = MinimumAcceptableScore(Score.initial())
existing: MutableMapping[str, ContigsPath] = {}
finder = OverlapFinder.make('ACTG')
extensions = calc_extension(finder, max_acceptable_prob, contigs, ContigsPath.empty())
extensions = calc_extension(finder, min_acceptable_score, contigs, ContigsPath.empty())
paths = tuple(filter_extensions(existing, extensions))
yield from paths

Expand All @@ -287,7 +332,7 @@ def calculate_all_paths(contigs: Sequence[ContigWithAligner]) -> Iterator[Contig
while paths:
logger.debug("Cycle %s started with %s paths.", cycle, len(paths))

extensions = calc_multiple_extensions(finder, max_acceptable_prob, paths, contigs)
extensions = calc_multiple_extensions(finder, min_acceptable_score, paths, contigs)
paths = tuple(filter_extensions(existing, extensions))

if paths:
Expand All @@ -305,17 +350,15 @@ def calculate_all_paths(contigs: Sequence[ContigWithAligner]) -> Iterator[Contig
# the most promising alternative at each point.
logger.debug("Dropping %s paths that have the lowest scores.",
len(paths) - MAX_ALTERNATIVES)
paths = tuple(sorted(paths, key=ContigsPath.score)[-MAX_ALTERNATIVES:])
paths = tuple(sorted(paths, key=lambda x: x.score)[-MAX_ALTERNATIVES:])

cycle += 1
yield from paths
if paths:
max_acceptable_prob = max(x.probability for x in paths)


def find_most_probable_path(contigs: Sequence[ContigWithAligner]) -> ContigsPath:
paths = calculate_all_paths(contigs)
return max(paths, key=ContigsPath.score)
return max(paths, key=lambda x: x.score)


def stitch_consensus(contigs: Iterable[ContigWithAligner]) -> Iterator[ContigWithAligner]:
Expand Down

0 comments on commit 4ea1fd5

Please sign in to comment.