Skip to content

Commit

Permalink
Contig stitcher: use context for logs handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Donaim committed Jan 24, 2024
1 parent 68a5b82 commit c5997ac
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 101 deletions.
113 changes: 53 additions & 60 deletions micall/core/contig_stitcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,33 @@

from micall.utils.cigar_tools import Cigar, connect_cigar_hits, CigarHit
from micall.utils.consensus_aligner import CigarActions
import micall.utils.contig_stitcher_events as events

T = TypeVar("T")
logger = logging.getLogger(__name__)


class StitcherContext:
def __init__(self):
self.name_generator_state: int = 0
self.events: List[events.EventType] = []

def generate_new_name(self):
self.name_generator_state += 1
return f"c{self.name_generator_state}"

def emit(self, event: events.EventType):
self.events.append(event)


context: ContextVar[StitcherContext] = ContextVar("StitcherContext", default=StitcherContext())


def with_fresh_context(body: Callable[[StitcherContext], T]) -> T:
return Context().run(lambda: body(context.get()))
def wrapper():
ctx = StitcherContext()
context.set(ctx)
return body(ctx)
return Context().run(wrapper)


@dataclass(frozen=True)
Expand Down Expand Up @@ -82,9 +90,8 @@ def cut_reference(self, cut_point: float) -> Tuple['AlignedContig', 'AlignedCont
right = replace(self, name=context.get().generate_new_name(), alignment=alignment_right)

logger.debug("Created contigs %r at %s and %r at %s by cutting %r.",
left.name, left.alignment, right.name, right.alignment, self.name,
extra={"action": "cut", "original": self,
"left": left, "right": right})
left.name, left.alignment, right.name, right.alignment, self.name)
context.get().emit(events.Cut(self, left, right))

return (left, right)

Expand All @@ -101,9 +108,8 @@ def lstrip_query(self) -> 'AlignedContig':
result = AlignedContig.make(query, alignment, self.strand)
logger.debug("Doing lstrip of %r resulted in %r, so %s (len %s) became %s (len %s)",
self.name, result.name, self.alignment,
len(self.seq), result.alignment, len(result.seq),
extra={"action": "modify", "type": "lstrip",
"original": self, "result": result})
len(self.seq), result.alignment, len(result.seq))
context.get().emit(events.LStrip(self, result))
return result


Expand All @@ -118,9 +124,8 @@ def rstrip_query(self) -> 'AlignedContig':
result = AlignedContig.make(query, alignment, self.strand)
logger.debug("Doing rstrip of %r resulted in %r, so %s (len %s) became %s (len %s)",
self.name, result.name, self.alignment,
len(self.seq), result.alignment, len(result.seq),
extra={"action": "modify", "type": "rstrip",
"original": self, "result": result})
len(self.seq), result.alignment, len(result.seq))
context.get().emit(events.RStrip(self, result))
return result


Expand Down Expand Up @@ -160,9 +165,8 @@ def munge(self, other: 'AlignedContig') -> 'AlignedContig':
assert self.strand == other.strand
ret = AlignedContig.make(query=query, alignment=alignment, strand=self.strand)
logger.debug("Munged contigs %r at %s with %r at %s resulting in %r at %s.",
self.name, self.alignment, other.name, other.alignment,
ret.name, ret.alignment, extra={"action": "munge", "left": self,
"right": other, "result": ret})
self.name, self.alignment, other.name, other.alignment, ret.name, ret.alignment)
context.get().emit(events.Munge(self, other, ret))
return ret


Expand Down Expand Up @@ -199,8 +203,8 @@ def combine_contigs(parts: List[AlignedContig]) -> AlignedContig:
ret = reduce(AlignedContig.munge, stripped_parts)
logger.debug("Created a frankenstein %r at %s (len %s) from %s.",
ret.name, ret.alignment, len(ret.seq),
[f"{x.name!r} at {x.alignment} (len {len(x.seq)})" for x in stripped_parts],
extra={"action": "combine", "contigs": stripped_parts, "result": ret})
[f"{x.name!r} at {x.alignment} (len {len(x.seq)})" for x in stripped_parts])
context.get().emit(events.Combine(stripped_parts, ret))
return ret


Expand All @@ -213,8 +217,8 @@ def align_to_reference(contig: GenotypedContig) -> Iterable[GenotypedContig]:
"""

if contig.ref_seq is None:
logger.info("Contig %r not aligned - no reference.", contig.name,
extra={"action": "alignment", "type": "noref", "contig": contig})
logger.info("Contig %r not aligned - no reference.", contig.name)
context.get().emit(events.NoRef(contig))
yield contig
return

Expand All @@ -229,29 +233,27 @@ def align_to_reference(contig: GenotypedContig) -> Iterable[GenotypedContig]:
connected = connect_cigar_hits(list(map(lambda p: p[0], hits_array))) if hits_array else []

if not connected:
logger.info("Contig %r not aligned - backend's choice.", contig.name,
extra={"action": "alignment", "type": "zerohits", "contig": contig})
logger.info("Contig %r not aligned - backend's choice.", contig.name)
context.get().emit(events.ZeroHits(contig))
yield contig
return

if len(set(map(lambda p: p[1], hits_array))) > 1:
logger.info("Discarding contig %r because it aligned both in forward and reverse sense.", contig.name,
extra={"action": "alignment", "type": "strandconflict", "contig": contig})
logger.info("Discarding contig %r because it aligned both in forward and reverse sense.", contig.name)
context.get().emit(events.StrandConflict(contig))
yield contig
return

logger.info("Contig %r produced %s aligner hits. After connecting them, the number became %s.",
contig.name, len(hits_array), len(connected),
extra={"action": "alignment", "type": "hitnumber", "contig": contig,
"initial": hits_array, "connected": connected})
contig.name, len(hits_array), len(connected))
context.get().emit(events.HitNumber(contig, hits_array, connected))

strand = hits_array[0][1]
if strand == "reverse":
rc = str(Seq(contig.seq).reverse_complement())
new_contig = replace(contig, seq=rc)
logger.info("Reverse complemented contig %r.", contig.name,
extra={"action": "alignment", "type": "reversecomplement",
"contig": contig, "result": new_contig})
logger.info("Reverse complemented contig %r.", contig.name)
context.get().emit(events.ReverseComplement(contig, new_contig))
contig = new_contig

for i, single_hit in enumerate(connected):
Expand All @@ -261,12 +263,10 @@ def align_to_reference(contig: GenotypedContig) -> Iterable[GenotypedContig]:
logger.info("Part %r of contig %r aligned as %r at [%s, %s]->[%s, %s]%s.",
i, contig.name,part.name,part.alignment.q_st,
part.alignment.q_ei,part.alignment.r_st,part.alignment.r_ei,
" (rev)" if strand == "reverse" else "",
extra={"action": "alignment", "type": "hit",
"contig": contig, "part":part, "i": i})
" (rev)" if strand == "reverse" else "")
logger.debug("Part %r of contig %r aligned as %r at %s%s.", i, contig.name,
part.name,part.alignment, " (rev)" if strand == "reverse" else "")

context.get().emit(events.Hit(contig, part, i))
yield part


Expand Down Expand Up @@ -455,10 +455,8 @@ def stitch_2_contigs(left, right):
left.name, left.alignment, len(left.seq),
right.name, right.alignment, len(right.seq),
left_overlap.name, left_overlap.alignment, len(left_overlap.seq),
right_overlap.name, right_overlap.alignment, len(right_overlap.seq),
extra={"action": "stitchcut", "left": left, "right": right,
"left_overlap": left_overlap, "right_overlap": right_overlap,
"left_remainder": left_remainder, "right_remainder": right_remainder})
right_overlap.name, right_overlap.alignment, len(right_overlap.seq))
context.get().emit(events.StitchCut(left, right, left_overlap, right_overlap, left_remainder, right_remainder))

# Align overlapping parts, then recombine based on concordance.
aligned_left, aligned_right = align_queries(left_overlap.seq, right_overlap.seq)
Expand All @@ -475,14 +473,11 @@ def stitch_2_contigs(left, right):
logger.debug("Created overlap contigs %r at %s and %r at %s based on parts of %r and %r, with avg. concordance %s%%, cut point at %s%%, and full concordance [%s].",
left_overlap_take.name, left_overlap.alignment, right_overlap_take.name, right_overlap_take.alignment,
left.name, right.name, round(average_concordance * 100),
round(cut_point_location_scaled * 100), concordance_str,
extra={"action": "overlap", "left": left, "right": right,
"left_remainder": left_remainder, "right_remainder": right_remainder,
"left_overlap": left_overlap, "right_overlap": right_overlap,
"left_take": left_overlap_take, "right_take": right_overlap_take,
"concordance": concordance, "avg": average_concordance,
"cut_point": max_concordance_index,
"cut_point_scaled": cut_point_location_scaled})
round(cut_point_location_scaled * 100), concordance_str)
context.get().emit(events.Overlap(left, right, left_overlap, right_overlap,
left_remainder, right_remainder, left_overlap_take,
right_overlap_take, concordance, average_concordance,
max_concordance_index, cut_point_location_scaled))

return combine_contigs([left_remainder, left_overlap_take, right_overlap_take, right_remainder])

Expand All @@ -501,8 +496,8 @@ def combine_overlaps(contigs: List[AlignedContig]) -> Iterable[AlignedContig]:
# Find overlap. If there isn't one - we are done with the current contig.
overlapping_contig = find_overlapping_contig(current, contigs)
if not overlapping_contig:
logger.info("Nothing overlaps with %r.", current.name,
extra={"action": "nooverlap", "contig": current})
logger.info("Nothing overlaps with %r.", current.name)
context.get().emit(events.NoOverlap(current))
yield current
continue

Expand All @@ -514,12 +509,11 @@ def combine_overlaps(contigs: List[AlignedContig]) -> Iterable[AlignedContig]:
logger.info("Stitching %r with %r results in %r at [%s,%s]->[%s,%s].",
current.name, overlapping_contig.name,
new_contig.name, new_contig.alignment.q_st, new_contig.alignment.q_ei,
new_contig.alignment.r_st, new_contig.alignment.r_ei,
extra={"action": "stitch", "result": new_contig,
"left": current, "right": overlapping_contig})
new_contig.alignment.r_st, new_contig.alignment.r_ei)
logger.debug("Stitching %r with %r results in %r at %s (len %s).",
current.name, overlapping_contig.name,
new_contig.name, new_contig.alignment, len(new_contig.seq))
context.get().emit(events.Stitch(current, overlapping_contig, new_contig))


def merge_intervals(intervals: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
Expand Down Expand Up @@ -588,8 +582,8 @@ def drop_completely_covered(contigs: List[AlignedContig]) -> List[AlignedContig]
if covered:
contigs.remove(covered)
logger.info("Droped contig %r as it is completely covered by these contigs: %s.",
covered.name, ", ".join(repr(x.name) for x in covering),
extra={"action": "drop", "contig": covered, "covering": covering})
covered.name, ", ".join(repr(x.name) for x in covering))
context.get().emit(events.Drop(covered, covering))
else:
break

Expand Down Expand Up @@ -624,8 +618,8 @@ def try_split(contig):
# overlaps around them.
# And we are likely to lose quality with every stitching operation.
# By skipping we assert that this gap is aligner's fault.
logger.debug("Ignored insignificant gap of %r, %s.", contig.name, gap,
extra={"action": "ignoregap", "contig": contig, "gap": gap})
logger.debug("Ignored insignificant gap of %r, %s.", contig.name, gap)
context.get().emit(events.IgnoreGap(contig, gap))
continue

if covered(contig, gap):
Expand All @@ -646,9 +640,8 @@ def try_split(contig):
left_part.name, left_part.alignment.q_st, left_part.alignment.q_ei,
left_part.alignment.r_st, left_part.alignment.r_ei,
right_part.name, right_part.alignment.q_st, right_part.alignment.q_ei,
right_part.alignment.r_st, right_part.alignment.r_ei,
extra={"action": "splitgap", "contig": contig,
"gap": gap, "left": left_part, "right": right_part})
right_part.alignment.r_st, right_part.alignment.r_ei)
context.get().emit(events.SplitGap(contig, gap, left_part, right_part))
return

process_queue: LifoQueue = LifoQueue()
Expand All @@ -665,11 +658,11 @@ def stitch_contigs(contigs: Iterable[GenotypedContig]) -> Iterable[GenotypedCont
contigs = list(contigs)
for contig in contigs:
logger.info("Introduced contig %r of ref %r, group_ref %r, and length %s.",
contig.name, contig.ref_name, contig.group_ref, len(contig.seq),
extra={"action": "intro", "contig": contig})
contig.name, contig.ref_name, contig.group_ref, len(contig.seq))
logger.debug("Introduced contig %r (seq = %s) of ref %r, group_ref %r (seq = %s), and length %s.",
contig.name, contig.seq, contig.ref_name,
contig.group_ref, contig.ref_seq, len(contig.seq))
context.get().emit(events.Intro(contig))

maybe_aligned = list(align_all_to_reference(contigs))

Expand Down Expand Up @@ -698,8 +691,8 @@ def combine(group_ref):
contigs = sorted(consensus_parts[group_ref], key=lambda x: x.alignment.r_st)
result = combine_contigs(contigs)
logger.debug("Combining these contigs for final output for %r: %s.",
group_ref, [f"{x.name!r} at {x.alignment} (len {len(x.seq)})" for x in contigs],
extra={"action": "finalcombine", "contigs": contigs, "result": result})
group_ref, [f"{x.name!r} at {x.alignment} (len {len(x.seq)})" for x in contigs])
context.get().emit(events.FinalCombine(contigs, result))
return result

yield from map(combine, consensus_parts)
Expand Down
5 changes: 1 addition & 4 deletions micall/core/denovo.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,6 @@ def write_contig_refs(contigs_fasta_path,
group_refs = {}

def run_stitcher(ctx):
logger = logging.getLogger("micall.core.contig_stitcher")
handler = add_structured_handler(logger)

genotypes = genotype(contigs_fasta_path,
blast_csv=blast_csv,
group_refs=group_refs)
Expand All @@ -102,7 +99,7 @@ def run_stitcher(ctx):
contig=contig.seq))

if stitcher_logger.level <= logging.DEBUG and stitcher_plot_path is not None:
plot_stitcher_coverage(handler.logs, stitcher_plot_path)
plot_stitcher_coverage(ctx.events, stitcher_plot_path)

return len(contigs)

Expand Down
Loading

0 comments on commit c5997ac

Please sign in to comment.