Skip to content

Commit

Permalink
Contig stitcher: add missing type signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
Donaim committed Jan 24, 2024
1 parent 834c89b commit 3e6d20b
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 58 deletions.
114 changes: 57 additions & 57 deletions micall/core/plot_contigs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import typing
from typing import Union, Dict, Tuple, List, Optional, Set, Iterable
from typing_extensions import Never
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, FileType
from collections import Counter, defaultdict
from csv import DictReader
Expand All @@ -19,6 +21,7 @@

from micall.core.project_config import ProjectConfig
from micall.utils.alignment_wrapper import align_nucs
from micall.core.contig_stitcher import Contig, GenotypedContig, AlignedContig
import micall.utils.contig_stitcher_events as events


Expand Down Expand Up @@ -395,19 +398,14 @@ def build_coverage_figure(genome_coverage_csv, blast_csv=None, use_concordance=F
return f


def plot_stitcher_coverage(logs, genome_coverage_svg_path):
def plot_stitcher_coverage(logs: Iterable[events.EventType], genome_coverage_svg_path: str):
f = build_stitcher_figure(logs)
f.show(w=970).save_svg(genome_coverage_svg_path, context=draw.Context(invert_y=True))
return f


from types import SimpleNamespace
from typing import Union, Dict, Tuple, List, Optional, Set
from micall.core.contig_stitcher import Contig, GenotypedContig, AlignedContig
import random

def build_stitcher_figure(logs) -> None:
contig_map: Dict[str, Contig] = {}
def build_stitcher_figure(logs: Iterable[events.EventType]) -> Figure:
contig_map: Dict[str, GenotypedContig] = {}
name_mappings: Dict[str, str] = {}
parent_graph: Dict[str, List[str]] = {}
morphism_graph: Dict[str, List[str]] = {}
Expand Down Expand Up @@ -515,7 +513,7 @@ def graph_sum(graph_a, graph_b):
def symmetric_closure(graph):
return graph_sum(graph, inverse_graph(graph))

def record_contig(contig: Contig, parents: List[Contig]):
def record_contig(contig: GenotypedContig, parents: List[GenotypedContig]):
contig_map[contig.name] = contig
if [contig.name] != [parent.name for parent in parents]:
for parent in parents:
Expand All @@ -532,7 +530,7 @@ def record_morphism(contig: Contig, original: Contig):
if contig.name not in lst:
lst.append(contig.name)

def record_bad_contig(contig: Contig, lst: List[Contig]):
def record_bad_contig(contig: GenotypedContig, lst: List[str]):
contig_map[contig.name] = contig
lst.append(contig.name)

Expand Down Expand Up @@ -585,11 +583,13 @@ def record_bad_contig(contig: Contig, lst: List[Contig]):
record_contig(event.right, [event.original])
elif isinstance(event, events.Combine):
record_contig(event.result, event.contigs)
combine_left_edge[event.result.name] = event.contigs[0].name
combine_right_edge[event.result.name] = event.contigs[-1].name
if event.contigs:
combine_left_edge[event.result.name] = event.contigs[0].name
combine_right_edge[event.result.name] = event.contigs[-1].name
elif isinstance(event, (events.IgnoreGap, events.NoOverlap)):
pass
else:
x: Never = event
raise RuntimeError(f"Unrecognized action or event: {event}")

group_refs = {contig.group_ref: len(contig.ref_seq) for contig in contig_map.values() if contig.ref_seq}
Expand All @@ -610,23 +610,23 @@ def record_bad_contig(contig: Contig, lst: List[Contig]):
eqv_morphism_graph = reflexive_closure(symmetric_closure(transitive_closure(morphism_graph)))
reduced_morphism_graph = reduced_closure(morphism_graph)

for contig in overlaps_list:
temporary.add(contig)
for child in transitive_children_graph.get(contig, []):
for contig_name in overlaps_list:
temporary.add(contig_name)
for child in transitive_children_graph.get(contig_name, []):
temporary.add(child)

for contig, parents in parent_graph.items():
for contig_name, parents in parent_graph.items():
if len(parents) > 2:
children_join_points.append(contig)
for contig, children in children_graph.items():
children_join_points.append(contig_name)
for contig_name, children in children_graph.items():
if len(children) > 2:
children_meet_points.append(contig)
children_meet_points.append(contig_name)

last_join_points_parent = {contig for join in children_join_points for contig in transitive_parent_graph.get(join, [])}
last_join_points_parent = {contig_name for join in children_join_points for contig_name in transitive_parent_graph.get(join, [])}
last_join_points = []
for contig in children_join_points:
if contig not in last_join_points_parent:
last_join_points.append(contig)
for contig_name in children_join_points:
if contig_name not in last_join_points_parent:
last_join_points.append(contig_name)

def set_query_position(contig: Contig):
if contig.name in query_position_map:
Expand All @@ -644,7 +644,7 @@ def set_query_position(contig: Contig):
if parent.name not in query_position_map:
set_query_position(parent)

average = sum(query_position_map[parent_name] for parent_name in parent_names) / len(parent_names)
average = round(sum(query_position_map[parent_name] for parent_name in parent_names) / len(parent_names))
query_position_map[contig.name] = average
else:
query_position_map[contig.name] = (contig.alignment.q_st + contig.alignment.q_ei) // 2
Expand All @@ -653,9 +653,9 @@ def set_query_position(contig: Contig):
set_query_position(contig)

# Closing `temporary'
for contig in contig_map:
if contig in temporary:
for clone in eqv_morphism_graph.get(contig, [contig]):
for contig_name in contig_map:
if contig_name in temporary:
for clone in eqv_morphism_graph.get(contig_name, [contig_name]):
temporary.add(clone)

def copy_takes_one_side(edge_table, overlap_xtake_map, overlap_xparent_map):
Expand All @@ -676,42 +676,42 @@ def copy_takes_one_side(edge_table, overlap_xtake_map, overlap_xparent_map):
while list(copy_takes_one_side(combine_left_edge, overlap_righttake_map, overlap_rightparent_map)): pass

final_parts: Dict[str, bool] = {}
for contig in contig_map:
if contig in temporary:
for contig_name in contig_map:
if contig_name in temporary:
continue

if contig in overlap_sibling_map:
finals = reduced_morphism_graph.get(contig, [contig])
if contig_name in overlap_sibling_map:
finals = reduced_morphism_graph.get(contig_name, [contig_name])
if len(finals) == 1:
[final] = finals
parents = reduced_parent_graph.get(final, [])
if len(parents) == 1:
final_parts[final] = True

elif contig in bad_contigs:
final_parts[contig] = True
elif contig_name in bad_contigs:
final_parts[contig_name] = True

for join in last_join_points + sorted_sinks:
parents = parent_graph.get(join, [join])
if not any(isinstance(contig_map[parent], AlignedContig) for parent in parents):
parents = [join]

for contig in parents:
for contig in reduced_morphism_graph.get(contig, [contig]):
if contig in bad_contigs:
for contig_name in parents:
for contig_name in reduced_morphism_graph.get(contig_name, [contig_name]):
if contig_name in bad_contigs:
continue

if any(contig in transitive_parent_graph.get(bad, []) for bad in bad_contigs):
if any(contig_name in transitive_parent_graph.get(bad, []) for bad in bad_contigs):
continue

if any(eqv in temporary for eqv in eqv_morphism_graph.get(contig, [contig])):
if any(eqv in temporary for eqv in eqv_morphism_graph.get(contig_name, [contig_name])):
continue

transitive_parent = eqv_parent_graph.get(contig, [contig])
transitive_parent = eqv_parent_graph.get(contig_name, [contig_name])
if any(parent in transitive_parent for parent in final_parts):
continue

final_parts[contig] = True
final_parts[contig_name] = True

final_parent_mapping: Dict[str, List[str]] = {}
for parent_name in sorted_roots:
Expand All @@ -725,7 +725,7 @@ def copy_takes_one_side(edge_table, overlap_xtake_map, overlap_xparent_map):

min_position, max_position = 1, 1
position_offset = 100
for contig in contig_map.values():
for _, contig in contig_map.items():
if isinstance(contig, GenotypedContig) and contig.ref_seq is not None:
max_position = max(max_position, len(contig.ref_seq) + 3 * position_offset)
else:
Expand All @@ -748,8 +748,8 @@ def copy_takes_one_side(edge_table, overlap_xtake_map, overlap_xparent_map):
k += 1
name_mappings[child] = f"{i + 1}.{k + 1}"

for contig, name in name_mappings.items():
logger.debug(f"Contig name {contig!r} is displayed as {name!r}.")
for contig_name, name in name_mappings.items():
logger.debug(f"Contig name {contig_name!r} is displayed as {name!r}.")

def get_neighbours(part, lookup):
for clone in eqv_morphism_graph.get(part.name, [part.name]):
Expand All @@ -771,8 +771,8 @@ def get_neighbour(part, lookup):
full_size_map: Dict[str, Tuple[int, int]] = {}

for parent_name in sorted_roots:
parts = final_parent_mapping[parent_name]
parts = [contig_map[part] for part in parts]
parts_names = final_parent_mapping[parent_name]
parts = [contig_map[part] for part in parts_names]

for part in parts:
if not isinstance(part, AlignedContig):
Expand Down Expand Up @@ -801,8 +801,8 @@ def get_neighbour(part, lookup):

aligned_size_map[part.name] = (r_st, r_ei)

sibling = ([overlap_sibling_map[name] for name in eqv_morphism_graph.get(part.name, [part.name]) if name in overlap_sibling_map] or [None])[0]
sibling = sibling and contig_map[sibling]
sibling_name = ([overlap_sibling_map[name] for name in eqv_morphism_graph.get(part.name, [part.name]) if name in overlap_sibling_map] or [""])[0]
sibling = sibling_name and contig_map[sibling_name]
prev_part = get_neighbour(sibling, overlap_lefttake_map)
next_part = get_neighbour(sibling, overlap_righttake_map)

Expand All @@ -820,7 +820,7 @@ def get_neighbour(part, lookup):

full_size_map[part.name] = (r_st, r_ei)

def get_contig_coordinates(contig):
def get_contig_coordinates(contig: GenotypedContig) -> Tuple[int, int, int, int]:
if isinstance(contig, AlignedContig):
r_st = position_offset + contig.alignment.r_st
r_ei = position_offset + contig.alignment.r_ei
Expand All @@ -841,7 +841,7 @@ def get_contig_coordinates(contig):
a_r_ei = f_r_ei
return (a_r_st, a_r_ei, f_r_st, f_r_ei)

def get_tracks(repeatset, group_ref, contig_name):
def get_tracks(repeatset: Set[str], group_ref: str, contig_name: str) -> Iterable[Track]:
parts = final_parent_mapping[contig_name]
for part_name in parts:
part = contig_map[part_name]
Expand All @@ -863,7 +863,7 @@ def get_tracks(repeatset, group_ref, contig_name):
(a_r_st, a_r_ei, f_r_st, f_r_ei) = get_contig_coordinates(part)
yield Track(f_r_st, f_r_ei, label=f"{indexes}")

def get_arrows(repeatset, group_ref, contig_name, labels):
def get_arrows(repeatset: Set[str], group_ref: str, contig_name: str, labels: bool) -> Iterable[Arrow]:
parts = final_parent_mapping[contig_name]
for part_name in parts:
part = contig_map[part_name]
Expand All @@ -890,8 +890,8 @@ def get_arrows(repeatset, group_ref, contig_name, labels):
h=height,
label=indexes)

def get_all_arrows(group_ref, labels):
repeatset = set()
def get_all_arrows(group_ref: str, labels: bool) -> Iterable[Arrow]:
repeatset: Set[str] = set()
for parent_name in sorted_roots:
yield from get_arrows(repeatset, group_ref, parent_name, labels)

Expand Down Expand Up @@ -966,8 +966,8 @@ def get_all_arrows(group_ref, labels):
# Contigs #
###########

repeatset1 = set()
repeatset2 = set()
repeatset1: Set[str] = set()
repeatset2: Set[str] = set()
for parent_name in sorted_roots:
arrows = list(get_arrows(repeatset1, group_ref, parent_name, labels=False))
if arrows:
Expand All @@ -992,7 +992,7 @@ def get_all_arrows(group_ref, labels):

contig = contig_map[contig_name]
(r_st, r_ei, f_r_st, f_r_ei) = get_contig_coordinates(contig)
name = name_mappings.get(contig.name, contig.name)
name = name_mappings.get(contig_name, contig_name)
figure.add(Arrow(r_st, r_ei, elevation=-20, h=1))
figure.add(Track(f_r_st, f_r_ei, label=name))

Expand Down Expand Up @@ -1021,7 +1021,7 @@ def get_all_arrows(group_ref, labels):
else:
colour = "red"

name = name_mappings.get(contig.name, contig.name)
name = name_mappings.get(contig_name, contig_name)
figure.add(Track(a_r_st, a_r_ei, color=colour, label=name))

###########
Expand All @@ -1042,7 +1042,7 @@ def get_all_arrows(group_ref, labels):
r_st = position_offset
r_ei = position_offset + len(contig.seq)
colour = "red"
name = name_mappings.get(contig.name, contig.name)
name = name_mappings.get(contig_name, contig_name)
figure.add(Track(r_st, r_ei, color=colour, label=name))

if not figure.elements:
Expand Down
2 changes: 1 addition & 1 deletion micall/utils/contig_stitcher_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,4 @@ class FinalCombine:

AlignmentEvent = Union[NoRef, ZeroHits, StrandConflict, HitNumber, ReverseComplement, Hit]
ModifyEvent = Union[LStrip, RStrip]
EventType = Union[Cut, ModifyEvent, Munge, AlignmentEvent, StitchCut, Overlap, NoOverlap, Stitch, Drop, IgnoreGap, SplitGap, Intro]
EventType = Union[Cut, ModifyEvent, Munge, Combine, AlignmentEvent, StitchCut, Overlap, NoOverlap, Stitch, Drop, IgnoreGap, SplitGap, Intro, FinalCombine]

0 comments on commit 3e6d20b

Please sign in to comment.