Skip to content

Commit

Permalink
Contig stitcher: fix a visualization of root combinations
Browse files Browse the repository at this point in the history
  • Loading branch information
Donaim committed Jan 26, 2024
1 parent e9617cb commit 51d8165
Show file tree
Hide file tree
Showing 8 changed files with 263 additions and 40 deletions.
5 changes: 2 additions & 3 deletions micall/core/contig_stitcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,7 @@ def align_to_reference(contig: GenotypedContig) -> Iterable[GenotypedContig]:
min(x.q_st, x.q_en - 1), max(x.q_st, x.q_en - 1)),
"forward" if x.strand == 1 else "reverse") for x in alignments]

connected = connect_cigar_hits(list(map(lambda p: p[0], hits_array))) if hits_array else []

connected = connect_cigar_hits([hit for hit, strand in hits_array]) if hits_array else []
if not connected:
logger.debug("Contig %r not aligned - backend's choice.", contig.name)
context.get().emit(events.ZeroHits(contig))
Expand Down Expand Up @@ -284,7 +283,7 @@ def get_indexes(name: str) -> Tuple[int, int]:
def is_out_of_order(name: str) -> bool:
return reference_sorted.index(name) != query_sorted.index(name)

sorted_by_query = list(sorted(contigs, key=lambda contig: contig.alignment.q_st if isinstance(contig, AlignedContig) else -1))
sorted_by_query = sorted(contigs, key=lambda contig: get_indexes(contig.name))
for prev_contig, contig, next_contig in sliding_window(sorted_by_query):
if isinstance(contig, AlignedContig):
name = contig.name
Expand Down
40 changes: 24 additions & 16 deletions micall/core/plot_contigs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import typing
from typing import Dict, Tuple, List, Set, Iterable, NoReturn
from typing import Dict, Tuple, List, Set, Iterable, NoReturn, Literal
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, FileType
from collections import Counter, defaultdict
from csv import DictReader
Expand All @@ -20,7 +20,8 @@

from micall.core.project_config import ProjectConfig
from micall.utils.alignment_wrapper import align_nucs
from micall.core.contig_stitcher import Contig, GenotypedContig, AlignedContig
from micall.core.contig_stitcher import Contig, GenotypedContig, AlignedContig, sliding_window
from micall.utils.cigar_tools import CigarHit
import micall.utils.contig_stitcher_events as events


Expand Down Expand Up @@ -419,6 +420,7 @@ def build_stitcher_figure(logs: Iterable[events.EventType]) -> Figure:
overlap_lefttake_map: Dict[str, str] = {}
overlap_righttake_map: Dict[str, str] = {}
overlap_sibling_map: Dict[str, str] = {}
combine_list: List[str] = []
combine_left_edge: Dict[str, str] = {}
combine_right_edge: Dict[str, str] = {}
temporary: Set[str] = set()
Expand Down Expand Up @@ -577,6 +579,9 @@ def record_bad_contig(contig: GenotypedContig, lst: List[str]):
record_contig(event.left, [event.original])
record_contig(event.right, [event.original])
elif isinstance(event, events.Combine):
for contig in event.contigs:
combine_list.append(contig.name)

record_contig(event.result, event.contigs)
if event.contigs:
combine_left_edge[event.result.name] = event.contigs[0].name
Expand Down Expand Up @@ -617,6 +622,12 @@ def record_bad_contig(contig: GenotypedContig, lst: List[str]):
if len(children) > 2:
children_meet_points.append(contig_name)

def hits_to_insertions(hits: List[CigarHit]):
for hit in hits:
yield CigarHit.from_default_alignment(q_st=0, q_ei=hit.q_st - 1, r_st=hit.q_st - 1, r_ei=hit.q_st - 2)
yield CigarHit.from_default_alignment(q_st=hit.q_ei + 1, q_ei=len(contig.seq) - 1, r_st=hit.q_ei + 1, r_ei=hit.q_ei)
yield from hit.insertions()

last_join_points_parent = {contig_name for join in children_join_points for contig_name in transitive_parent_graph.get(join, [])}
last_join_points = []
for contig_name in children_join_points:
Expand Down Expand Up @@ -675,13 +686,10 @@ def copy_takes_one_side(edge_table, overlap_xtake_map, overlap_xparent_map):
if contig_name in temporary:
continue

if contig_name in overlap_sibling_map:
if contig_name in combine_list:
finals = reduced_morphism_graph.get(contig_name, [contig_name])
if len(finals) == 1:
[final] = finals
parents = reduced_parent_graph.get(final, [])
if len(parents) == 1:
final_parts[final] = True
final_parts[finals[0]] = True

elif contig_name in bad_contigs:
final_parts[contig_name] = True
Expand All @@ -708,15 +716,15 @@ def copy_takes_one_side(edge_table, overlap_xtake_map, overlap_xparent_map):

final_parts[contig_name] = True

final_parent_mapping: Dict[str, List[str]] = {}
final_children_mapping: Dict[str, List[str]] = {}
for parent_name in sorted_roots:
children = []
for final_contig in final_parts:
if final_contig == parent_name or \
parent_name in reduced_parent_graph.get(final_contig, []):
children.append(final_contig)

final_parent_mapping[parent_name] = children
final_children_mapping[parent_name] = children

min_position, max_position = 1, 1
position_offset = 100
Expand All @@ -727,7 +735,7 @@ def copy_takes_one_side(edge_table, overlap_xtake_map, overlap_xparent_map):
max_position = max(max_position, len(contig.seq) + 3 * position_offset)

name_mappings = {}
for i, (parent, children) in enumerate(sorted(final_parent_mapping.items(), key=lambda p: p[0])):
for i, (parent, children) in enumerate(sorted(final_children_mapping.items(), key=lambda p: p[0])):
name_mappings[parent] = f"{i + 1}"
children = list(sorted(children, key=lambda name: query_position_map.get(name, -1)))
for k, child in enumerate(children):
Expand Down Expand Up @@ -766,7 +774,7 @@ def get_neighbour(part, lookup):
full_size_map: Dict[str, Tuple[int, int]] = {}

for parent_name in sorted_roots:
parts_names = final_parent_mapping[parent_name]
parts_names = final_children_mapping[parent_name]
parts = [contig_map[part] for part in parts_names]

for part in parts:
Expand Down Expand Up @@ -837,7 +845,7 @@ def get_contig_coordinates(contig: GenotypedContig) -> Tuple[int, int, int, int]
return (a_r_st, a_r_ei, f_r_st, f_r_ei)

def get_tracks(repeatset: Set[str], group_ref: str, contig_name: str) -> Iterable[Track]:
parts = final_parent_mapping[contig_name]
parts = final_children_mapping[contig_name]
for part_name in parts:
part = contig_map[part_name]

Expand All @@ -859,7 +867,7 @@ def get_tracks(repeatset: Set[str], group_ref: str, contig_name: str) -> Iterabl
yield Track(f_r_st, f_r_ei, label=f"{indexes}")

def get_arrows(repeatset: Set[str], group_ref: str, contig_name: str, labels: bool) -> Iterable[Arrow]:
parts = final_parent_mapping[contig_name]
parts = final_children_mapping[contig_name]
for part_name in parts:
part = contig_map[part_name]

Expand Down Expand Up @@ -980,7 +988,7 @@ def get_all_arrows(group_ref: str, labels: bool) -> Iterable[Arrow]:
pos = position_offset / 2
figure.add(Track(pos, pos, h=40, label=label))
for parent_name in sorted_roots:
contigs = final_parent_mapping.get(parent_name, [])
contigs = final_children_mapping.get(parent_name, [])
for contig_name in contigs:
if contig_name not in discarded:
continue
Expand All @@ -1000,7 +1008,7 @@ def get_all_arrows(group_ref: str, labels: bool) -> Iterable[Arrow]:
pos = position_offset / 2
figure.add(Track(pos, pos, h=40, label=label))
for parent_name in sorted_roots:
contigs = final_parent_mapping.get(parent_name, [])
contigs = final_children_mapping.get(parent_name, [])
for contig_name in contigs:
if contig_name not in anomaly:
continue
Expand Down Expand Up @@ -1028,7 +1036,7 @@ def get_all_arrows(group_ref: str, labels: bool) -> Iterable[Arrow]:
pos = position_offset / 2
figure.add(Track(pos, pos, h=40, label=label))
for parent_name in sorted_roots:
contigs = final_parent_mapping.get(parent_name, [])
contigs = final_children_mapping.get(parent_name, [])
for contig_name in contigs:
if contig_name not in unknown:
continue
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
70 changes: 70 additions & 0 deletions micall/tests/data/stitcher_plots/test_gap_around_big_insertion.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
72 changes: 72 additions & 0 deletions micall/tests/test_contig_stitcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def visualizer(request, tmp_path):
os.makedirs(plots_dir, exist_ok=True)
path_to_expected = os.path.join(plots_dir, plot_name)
path_to_produced = os.path.join(tmp_path, plot_name)
# path_to_produced = path_to_expected

def check():
logs = stitcher.context.get().events
Expand Down Expand Up @@ -687,6 +688,77 @@ def test_partial_align_consensus_multiple_overlaping_sequences(exact_aligner, vi
assert len(visualizer().elements) > len(contigs)


def test_big_insertion_in_a_single_contig(exact_aligner, visualizer):
# Scenario: Single contig produces many alignments.

ref_seq='A' * 10 + 'B' * 20 + 'C' * 10

contigs = [
GenotypedContig(name='a',
seq='B' * 10 + 'D' * 100 + 'B' * 10,
ref_name='testref',
group_ref='testref',
ref_seq=ref_seq,
match_fraction=0.3,
),
]

results = list(stitch_consensus(contigs))
assert len(results) == 1
assert results[0].seq == contigs[0].seq

assert len(visualizer().elements) > len(contigs)


def test_big_insertion_in_a_single_contig_2(exact_aligner, visualizer):
# Scenario: Single contig produces many alignments.

ref_seq='A' * 10 + 'B' * 20 + 'C' * 10

contigs = [
GenotypedContig(name='a',
seq='A' * 10 + 'D' * 100 + 'C' * 10,
ref_name='testref',
group_ref='testref',
ref_seq=ref_seq,
match_fraction=0.3,
),
]

results = list(stitch_consensus(contigs))
assert len(results) == 1
assert results[0].seq == contigs[0].seq

assert len(visualizer().elements) > len(contigs)


def test_gap_around_big_insertion(exact_aligner, visualizer):
# Scenario: Contig is split around its gap, then stripped.

ref_seq='A' * 10 + 'B' * 20 + 'C' * 10

contigs = [
GenotypedContig(name='a',
seq='A' * 10 + 'D' * 100 + 'C' * 10,
ref_name='testref',
group_ref='testref',
ref_seq=ref_seq,
match_fraction=0.3,
),
GenotypedContig(name='b',
seq='B' * 20,
ref_name='testref',
group_ref='testref',
ref_seq=ref_seq,
match_fraction=0.3,
),
]

results = list(stitch_consensus(contigs))
assert len(results) == 1
assert len(visualizer().elements) > len(contigs)


def test_main_invocation(exact_aligner, tmp_path, hcv_db):
pwd = os.path.dirname(__file__)
contigs = os.path.join(pwd, "data", "exact_parts_contigs.csv")
Expand Down
Loading

0 comments on commit 51d8165

Please sign in to comment.