Skip to content

Commit

Permalink
Contig stitcher: introduce a proper context for the name generator
Browse files Browse the repository at this point in the history
  • Loading branch information
Donaim committed Jan 23, 2024
1 parent 6caf9ee commit 68a5b82
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 31 deletions.
38 changes: 23 additions & 15 deletions micall/core/contig_stitcher.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, Optional, Tuple, List, Dict, Union, Literal, TypeVar
from typing import Iterable, Optional, Tuple, List, Dict, Union, Literal, TypeVar, Callable
from collections import deque, defaultdict
from dataclasses import dataclass, replace
from math import ceil, floor
Expand All @@ -9,19 +9,29 @@
from queue import LifoQueue
from Bio import Seq
import logging
from contextvars import ContextVar, Context

from micall.utils.cigar_tools import Cigar, connect_cigar_hits, CigarHit
from micall.utils.consensus_aligner import CigarActions


T = TypeVar("T")
logger = logging.getLogger(__name__)


# Module-wide counter behind generate_new_name(); it is never reset, so
# names keep increasing for the lifetime of the process.
name_generator_state = 0


def generate_new_name():
    """Return the next contig name in the sequence "c1", "c2", ...

    Each call advances the module-level `name_generator_state` by one and
    formats the new value with a "c" prefix.
    """
    global name_generator_state
    name_generator_state = name_generator_state + 1
    return f"c{name_generator_state}"
class StitcherContext:
    """Mutable state used by the stitcher while it runs.

    The only state kept here is the counter behind generated contig names.
    """

    def __init__(self) -> None:
        # Number of names handed out so far by generate_new_name().
        self.name_generator_state: int = 0

    def generate_new_name(self) -> str:
        """Return the next name in the sequence "c1", "c2", ..."""
        self.name_generator_state += 1
        return f"c{self.name_generator_state}"


# NOTE: the default instance is created once, when this module is imported.
# Any code that calls context.get() without a value having been set in its
# current Context receives this single shared object.
context: ContextVar[StitcherContext] = ContextVar("StitcherContext", default=StitcherContext())


def with_fresh_context(body: Callable[[StitcherContext], T]) -> T:
    """Run `body` with a brand-new StitcherContext and return its result.

    `body` receives the new StitcherContext as its only argument, and
    `context.get()` returns that same object for the whole duration of the
    call. The caller's own context is left untouched.
    """

    def run_body():
        fresh = StitcherContext()
        # An empty Context() is not enough on its own: with no value set,
        # context.get() falls back to the ContextVar's default, which is the
        # one shared instance created at import time. Setting a new instance
        # here is what actually makes the name counter start from zero.
        context.set(fresh)
        return body(fresh)

    # The set() above happens inside the new Context, so it is discarded when
    # run() returns and never leaks into the caller's context.
    return Context().run(run_body)


@dataclass(frozen=True)
Expand All @@ -41,8 +51,8 @@ def cut_query(self, cut_point: float) -> Tuple['GenotypedContig', 'GenotypedCont
""" Cuts query sequence in two parts with cut_point between them. """

cut_point = max(0, cut_point)
left = replace(self, name=generate_new_name(), seq=self.seq[:ceil(cut_point)])
right = replace(self, name=generate_new_name(), seq=self.seq[ceil(cut_point):])
left = replace(self, name=context.get().generate_new_name(), seq=self.seq[:ceil(cut_point)])
right = replace(self, name=context.get().generate_new_name(), seq=self.seq[ceil(cut_point):])
return (left, right)


Expand All @@ -68,8 +78,8 @@ def cut_reference(self, cut_point: float) -> Tuple['AlignedContig', 'AlignedCont
""" Cuts this alignment in two parts with cut_point between them. """

alignment_left, alignment_right = self.alignment.cut_reference(cut_point)
left = replace(self, name=generate_new_name(), alignment=alignment_left)
right = replace(self, name=generate_new_name(), alignment=alignment_right)
left = replace(self, name=context.get().generate_new_name(), alignment=alignment_left)
right = replace(self, name=context.get().generate_new_name(), alignment=alignment_right)

logger.debug("Created contigs %r at %s and %r at %s by cutting %r.",
left.name, left.alignment, right.name, right.alignment, self.name,
Expand Down Expand Up @@ -134,7 +144,7 @@ def munge(self, other: 'AlignedContig') -> 'AlignedContig':
match_fraction = min(self.match_fraction, other.match_fraction)
ref_name = max([self, other], key=lambda x: x.alignment.ref_length).ref_name
query = GenotypedContig(seq=self.seq + other.seq,
name=generate_new_name(),
name=context.get().generate_new_name(),
ref_name=ref_name,
group_ref=self.group_ref,
ref_seq=self.ref_seq,
Expand All @@ -156,8 +166,6 @@ def munge(self, other: 'AlignedContig') -> 'AlignedContig':
return ret


T = TypeVar("T")

def sliding_window(sequence: Iterable[T]) -> Iterable[Tuple[Optional[T], T, Optional[T]]]:
"""
Generate a three-element sliding window of a sequence.
Expand Down Expand Up @@ -247,7 +255,7 @@ def align_to_reference(contig: GenotypedContig) -> Iterable[GenotypedContig]:
contig = new_contig

Check warning on line 255 in micall/core/contig_stitcher.py

View check run for this annotation

Codecov / codecov/patch

micall/core/contig_stitcher.py#L255

Added line #L255 was not covered by tests

for i, single_hit in enumerate(connected):
query = replace(contig, name=generate_new_name())
query = replace(contig, name=context.get().generate_new_name())
part = AlignedContig.make(query, single_hit, strand)

logger.info("Part %r of contig %r aligned as %r at [%s, %s]->[%s, %s]%s.",
Expand Down
35 changes: 19 additions & 16 deletions micall/core/denovo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from Bio.SeqRecord import SeqRecord

from micall.core.project_config import ProjectConfig
from micall.core.contig_stitcher import GenotypedContig, stitch_consensus, logger as stitcher_logger
from micall.core.contig_stitcher import GenotypedContig, stitch_consensus, logger as stitcher_logger, with_fresh_context
from micall.core.plot_contigs import plot_stitcher_coverage
from micall.utils.structured_logger import add_structured_handler

Expand Down Expand Up @@ -84,26 +84,29 @@ def write_contig_refs(contigs_fasta_path,
contigs_fasta.write(f">{contig_name}\n{row['contig']}\n")
group_refs = {}

logger = logging.getLogger("micall.core.contig_stitcher")
handler = add_structured_handler(logger)
def run_stitcher(ctx):
logger = logging.getLogger("micall.core.contig_stitcher")
handler = add_structured_handler(logger)

genotypes = genotype(contigs_fasta_path,
blast_csv=blast_csv,
group_refs=group_refs)
genotypes = genotype(contigs_fasta_path,
blast_csv=blast_csv,
group_refs=group_refs)

contigs = list(read_assembled_contigs(group_refs, genotypes, contigs_fasta_path))
contigs = list(stitch_consensus(contigs))
contigs = list(read_assembled_contigs(group_refs, genotypes, contigs_fasta_path))
contigs = list(stitch_consensus(contigs))

for contig in contigs:
writer.writerow(dict(ref=contig.ref_name,
match=contig.match_fraction,
group_ref=contig.group_ref,
contig=contig.seq))
for contig in contigs:
writer.writerow(dict(ref=contig.ref_name,
match=contig.match_fraction,
group_ref=contig.group_ref,
contig=contig.seq))

if stitcher_logger.level <= logging.DEBUG and stitcher_plot_path is not None:
plot_stitcher_coverage(handler.logs, stitcher_plot_path)
if stitcher_logger.level <= logging.DEBUG and stitcher_plot_path is not None:
plot_stitcher_coverage(handler.logs, stitcher_plot_path)

return len(contigs)
return len(contigs)

return with_fresh_context(run_stitcher)


def genotype(fasta, db=DEFAULT_DATABASE, blast_csv=None, group_refs=None):
Expand Down

0 comments on commit 68a5b82

Please sign in to comment.