From 68a5b82739f330153c9d1a60a722c43429544097 Mon Sep 17 00:00:00 2001 From: Vitaliy Mysak Date: Tue, 23 Jan 2024 15:28:33 -0800 Subject: [PATCH] Contig stitcher: introduce a proper context for the name generator --- micall/core/contig_stitcher.py | 38 ++++++++++++++++++++-------------- micall/core/denovo.py | 35 +++++++++++++++++-------------- 2 files changed, 42 insertions(+), 31 deletions(-) diff --git a/micall/core/contig_stitcher.py b/micall/core/contig_stitcher.py index 61e0c5742..70ca50e37 100644 --- a/micall/core/contig_stitcher.py +++ b/micall/core/contig_stitcher.py @@ -1,4 +1,4 @@ -from typing import Iterable, Optional, Tuple, List, Dict, Union, Literal, TypeVar +from typing import Iterable, Optional, Tuple, List, Dict, Union, Literal, TypeVar, Callable from collections import deque, defaultdict from dataclasses import dataclass, replace from math import ceil, floor @@ -9,19 +9,29 @@ from queue import LifoQueue from Bio import Seq import logging +from contextvars import ContextVar, Context from micall.utils.cigar_tools import Cigar, connect_cigar_hits, CigarHit from micall.utils.consensus_aligner import CigarActions - +T = TypeVar("T") logger = logging.getLogger(__name__) -name_generator_state = 0 -def generate_new_name(): - global name_generator_state - name_generator_state += 1 - return f"c{name_generator_state}" +class StitcherContext: + def __init__(self): + self.name_generator_state: int = 0 + + def generate_new_name(self): + self.name_generator_state += 1 + return f"c{self.name_generator_state}" + + +context: ContextVar[StitcherContext] = ContextVar("StitcherContext", default=StitcherContext()) + + +def with_fresh_context(body: Callable[[StitcherContext], T]) -> T: + return Context().run(lambda: body(context.get())) @dataclass(frozen=True) @@ -41,8 +51,8 @@ def cut_query(self, cut_point: float) -> Tuple['GenotypedContig', 'GenotypedCont """ Cuts query sequence in two parts with cut_point between them. """ cut_point = max(0, cut_point) - left = replace(self, name=generate_new_name(), seq=self.seq[:ceil(cut_point)]) - right = replace(self, name=generate_new_name(), seq=self.seq[ceil(cut_point):]) + left = replace(self, name=context.get().generate_new_name(), seq=self.seq[:ceil(cut_point)]) + right = replace(self, name=context.get().generate_new_name(), seq=self.seq[ceil(cut_point):]) return (left, right) @@ -68,8 +78,8 @@ def cut_reference(self, cut_point: float) -> Tuple['AlignedContig', 'AlignedCont """ Cuts this alignment in two parts with cut_point between them. """ alignment_left, alignment_right = self.alignment.cut_reference(cut_point) - left = replace(self, name=generate_new_name(), alignment=alignment_left) - right = replace(self, name=generate_new_name(), alignment=alignment_right) + left = replace(self, name=context.get().generate_new_name(), alignment=alignment_left) + right = replace(self, name=context.get().generate_new_name(), alignment=alignment_right) logger.debug("Created contigs %r at %s and %r at %s by cutting %r.", left.name, left.alignment, right.name, right.alignment, self.name, @@ -134,7 +144,7 @@ def munge(self, other: 'AlignedContig') -> 'AlignedContig': match_fraction = min(self.match_fraction, other.match_fraction) ref_name = max([self, other], key=lambda x: x.alignment.ref_length).ref_name query = GenotypedContig(seq=self.seq + other.seq, - name=generate_new_name(), + name=context.get().generate_new_name(), ref_name=ref_name, group_ref=self.group_ref, ref_seq=self.ref_seq, @@ -156,8 +166,6 @@ def munge(self, other: 'AlignedContig') -> 'AlignedContig': return ret -T = TypeVar("T") - def sliding_window(sequence: Iterable[T]) -> Iterable[Tuple[Optional[T], T, Optional[T]]]: """ Generate a three-element sliding window of a sequence. @@ -247,7 +255,7 @@ def align_to_reference(contig: GenotypedContig) -> Iterable[GenotypedContig]: contig = new_contig for i, single_hit in enumerate(connected): - query = replace(contig, name=generate_new_name()) + query = replace(contig, name=context.get().generate_new_name()) part = AlignedContig.make(query, single_hit, strand) logger.info("Part %r of contig %r aligned as %r at [%s, %s]->[%s, %s]%s.", diff --git a/micall/core/denovo.py b/micall/core/denovo.py index 6cbdf956b..fbdb654b3 100644 --- a/micall/core/denovo.py +++ b/micall/core/denovo.py @@ -19,7 +19,7 @@ from Bio.SeqRecord import SeqRecord from micall.core.project_config import ProjectConfig -from micall.core.contig_stitcher import GenotypedContig, stitch_consensus, logger as stitcher_logger +from micall.core.contig_stitcher import GenotypedContig, stitch_consensus, logger as stitcher_logger, with_fresh_context from micall.core.plot_contigs import plot_stitcher_coverage from micall.utils.structured_logger import add_structured_handler @@ -84,26 +84,29 @@ def write_contig_refs(contigs_fasta_path, contigs_fasta.write(f">{contig_name}\n{row['contig']}\n") group_refs = {} - logger = logging.getLogger("micall.core.contig_stitcher") - handler = add_structured_handler(logger) + def run_stitcher(ctx): + logger = logging.getLogger("micall.core.contig_stitcher") + handler = add_structured_handler(logger) - genotypes = genotype(contigs_fasta_path, - blast_csv=blast_csv, - group_refs=group_refs) + genotypes = genotype(contigs_fasta_path, + blast_csv=blast_csv, + group_refs=group_refs) - contigs = list(read_assembled_contigs(group_refs, genotypes, contigs_fasta_path)) - contigs = list(stitch_consensus(contigs)) + contigs = list(read_assembled_contigs(group_refs, genotypes, contigs_fasta_path)) + contigs = list(stitch_consensus(contigs)) - for contig in contigs: - writer.writerow(dict(ref=contig.ref_name, - match=contig.match_fraction, - group_ref=contig.group_ref, - contig=contig.seq)) + for contig in contigs: + writer.writerow(dict(ref=contig.ref_name, + match=contig.match_fraction, + group_ref=contig.group_ref, + contig=contig.seq)) - if stitcher_logger.level <= logging.DEBUG and stitcher_plot_path is not None: - plot_stitcher_coverage(handler.logs, stitcher_plot_path) + if stitcher_logger.level <= logging.DEBUG and stitcher_plot_path is not None: + plot_stitcher_coverage(handler.logs, stitcher_plot_path) - return len(contigs) + return len(contigs) + + return with_fresh_context(run_stitcher) def genotype(fasta, db=DEFAULT_DATABASE, blast_csv=None, group_refs=None):