diff --git a/sc2ts/cli.py b/sc2ts/cli.py index 7c37a68..79af3b8 100644 --- a/sc2ts/cli.py +++ b/sc2ts/cli.py @@ -633,103 +633,6 @@ def find_previous_date_path(date, path_pattern): return path -@click.command() -@click.argument("dataset", type=click.Path(exists=True, dir_okay=False)) -@click.argument("ts", type=click.Path(exists=True, dir_okay=False)) -@click.argument("path_pattern") -@num_mismatches -@click.option( - "--num-threads", - default=0, - type=int, - help="Number of match threads (default to one)", -) -@click.option("--progress/--no-progress", default=True) -@click.option("-v", "--verbose", count=True) -@click.option("-l", "--log-file", default=None, type=click.Path(dir_okay=False)) -def rematch_recombinants( - dataset, - ts, - path_pattern, - num_mismatches, - num_threads, - progress, - verbose, - log_file, -): - setup_logging(verbose, log_file) - ts = tszip.load(ts) - # This is a map of recombinant node to the samples involved in - # the original causal sample group. - recombinant_strains = sc2ts.get_recombinant_strains(ts) - logger.info( - f"Got {len(recombinant_strains)} recombinants and " - f"{sum(len(v) for v in recombinant_strains.values())} strains" - ) - - # Map recombinants to originating date - recombinant_to_path = {} - strain_to_recombinant = {} - all_strains = [] - for u, strains in recombinant_strains.items(): - date_added = ts.node(u).metadata["sc2ts"]["date_added"] - base_ts_path = find_previous_date_path(date_added, path_pattern) - recombinant_to_path[u] = base_ts_path - for strain in strains: - strain_to_recombinant[strain] = u - all_strains.append(strain) - - ds = sc2ts.Dataset(dataset) - progress_title = "Recomb" - samples = sc2ts.preprocess( - all_strains, - datset=ds, - show_progress=progress, - progress_title=progress_title, - keep_sites=ts.sites_position.astype(int), - num_workers=num_threads, - ) - - recombinant_to_samples = collections.defaultdict(list) - for sample in samples: - if sample.haplotype is None: - raise ValueError(f"No alignment stored for {sample.strain}") - recombinant = strain_to_recombinant[sample.strain] - recombinant_to_samples[recombinant].append(sample) - - work = [] - for recombinant, samples in recombinant_to_samples.items(): - for direction in ["forward", "reverse"]: - work.append( - MatchWork( - recombinant_to_path[recombinant], - samples, - num_mismatches=num_mismatches, - direction=direction, - ) - ) - - bar = sc2ts.get_progress(None, progress_title, "HMM", progress, total=len(work)) - - def output(hmm_runs): - bar.update() - for run in hmm_runs: - print(run.asjson()) - - results = [] - if num_threads == 0: - for w in work: - hmm_runs = _match_worker(w) - output(hmm_runs) - else: - with cf.ProcessPoolExecutor(num_threads) as executor: - futures = [executor.submit(_match_worker, w) for w in work] - for future in cf.as_completed(futures): - hmm_runs = future.result() - output(hmm_runs) - bar.close() - - @click.version_option(core.__version__) @click.group() def cli(): @@ -747,5 +650,4 @@ def cli(): cli.add_command(infer) cli.add_command(validate) cli.add_command(_match) -cli.add_command(rematch_recombinants) cli.add_command(tally_lineages) diff --git a/sc2ts/inference.py b/sc2ts/inference.py index f303261..488b127 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -366,7 +366,6 @@ class Sample: hmm_match: HmmMatch = None breakpoint_intervals: List = dataclasses.field(default_factory=list) flags: int = tskit.NODE_IS_SAMPLE - hmm_reruns: Dict = dataclasses.field(default_factory=dict) @property def is_recombinant(self): @@ -382,46 +381,7 @@ def num_deletion_sites(self): def summary(self): hmm_match = "No match" if self.hmm_match is None else self.hmm_match.summary() - s = f"{self.strain} {self.date} {self.pango} {hmm_match}" - for name, hmm_match in self.hmm_reruns.items(): - s += f"; {name}: {hmm_match.summary()}" - return s - - -# TODO not clear if we still need this as mirroring is done differently now. -# Remove if we don't have any issues with running the HMM in reverse -def pad_sites(ts): - """ - Fill in missing sites with the reference state. - """ - ref = core.get_reference_sequence() - missing_sites = set(np.arange(1, len(ref))) - missing_sites -= set(ts.sites_position.astype(int)) - tables = ts.dump_tables() - for pos in missing_sites: - tables.sites.add_row(pos, ref[pos]) - tables.sort() - return tables.tree_sequence() - - -# TODO remove this -def match_recombinants( - samples, base_ts, num_mismatches, show_progress=False, num_threads=None -): - for hmm_pass in ["forward", "reverse", "no_recombination"]: - logger.info(f"Running {hmm_pass} pass for {len(samples)} recombinants") - match_tsinfer( - samples=samples, - ts=base_ts, - num_mismatches=1000 if hmm_pass == "no_recombination" else num_mismatches, - mismatch_threshold=100, - num_threads=num_threads, - show_progress=show_progress, - mirror_coordinates=hmm_pass == "reverse", - ) - - for sample in samples: - sample.hmm_reruns[hmm_pass] = sample.hmm_match + return f"{self.strain} {self.date} {self.pango} {hmm_match}" def match_samples( @@ -838,7 +798,6 @@ def update_top_level_metadata(ts, date, retro_groups, samples): def add_sample_to_tables(sample, tables, group_id=None): sc2ts_md = { "hmm_match": sample.hmm_match.asdict(), - "hmm_reruns": {k: m.asdict() for k, m in sample.hmm_reruns.items()}, "alignment_composition": dict(sample.alignment_composition), "num_missing_sites": sample.num_missing_sites, } diff --git a/tests/test_inference.py b/tests/test_inference.py index 316ae88..f7d8aab 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1022,7 +1022,6 @@ def test_recombinant_example_1(self, fx_ts_map, fx_recombinant_example_1): {"left": bp, "parent": right_parent, "right": 29904}, ], } - assert smd["hmm_reruns"] == {} sample = ts.node(ts.samples()[-1]) smd = sample.metadata["sc2ts"] @@ -1036,7 +1035,6 @@ def test_recombinant_example_1(self, fx_ts_map, fx_recombinant_example_1): {"left": bp, "parent": right_parent, "right": 29904}, ], } - assert smd["hmm_reruns"] == {} recomb_node = ts.node(ts.num_nodes - 1) assert recomb_node.flags == sc2ts.NODE_IS_RECOMBINANT @@ -1078,7 +1076,7 @@ def test_recombinant_example_2(self, fx_ts_map, fx_recombinant_example_2): date = "2020-03-01" rts = fx_recombinant_example_2 samples_strain = rts.metadata["sc2ts"]["samples_strain"] - assert samples_strain[-3:] == ["left", "right", "recombinant"] + assert samples_strain[-3:] == ["left", "right", "recombinant_114:29825"] sample = rts.node(rts.samples()[-1]) smd = sample.metadata["sc2ts"] @@ -1090,8 +1088,6 @@ def test_recombinant_example_2(self, fx_ts_map, fx_recombinant_example_2): ], } - assert smd["hmm_reruns"] == {} - def test_all_As(self, tmp_path, fx_ts_map, fx_dataset): # Same as the recombinant_example_1() function above # Just to get something that looks like an alignment easily @@ -1310,71 +1306,3 @@ def test_example_3(self, fx_recombinant_example_3): m = s.hmm_match assert m.parents == [53, 54, 55] assert m.breakpoints == [0, 114, 15010, 29904] - - -class TestMatchRecombinants: - def test_example_1(self, fx_ts_map): - ts, s = recombinant_example_1(fx_ts_map) - - sc2ts.match_recombinants( - samples=[s], - base_ts=ts, - num_mismatches=2, - num_threads=0, - ) - left_parent = 31 - right_parent = 46 - interval_right = 11083 - - m = s.hmm_reruns["forward"] - assert len(m.mutations) == 0 - assert len(m.path) == 2 - assert m.path[0].parent == left_parent - assert m.path[0].left == 0 - assert m.path[0].right == interval_right - assert m.path[1].parent == right_parent - assert m.path[1].left == interval_right - assert m.path[1].right == ts.sequence_length - - interval_left = 3788 - m = s.hmm_reruns["reverse"] - assert len(m.mutations) == 0 - assert len(m.path) == 2 - assert m.path[0].left == 0 - assert m.path[0].right == interval_left - assert m.path[0].parent == left_parent - assert m.path[1].parent == right_parent - assert m.path[1].left == interval_left - assert m.path[1].right == ts.sequence_length - - m = s.hmm_reruns["no_recombination"] - # It seems that we can choose either the left or right parent - # arbitrarily :shrug: - assert len(m.mutations) == 3 - assert m.mutation_summary() in [ - "[A871G, A3027G, C3787T]", - "[T11083G, C15324T, C29303T]", - ] - assert len(m.path) == 1 - assert m.path[0].parent in [left_parent, right_parent] - assert m.path[0].left == 0 - assert m.path[0].right == ts.sequence_length - - assert "no_recombination" in s.summary() - - def test_all_As(self, fx_ts_map): - ts = fx_ts_map["2020-02-13"] - h = np.zeros(ts.num_sites, dtype=np.int8) - s = sc2ts.Sample("zerotype", "2020-02-14", haplotype=h) - - sc2ts.match_recombinants( - samples=[s], - base_ts=ts, - num_mismatches=3, - num_threads=0, - ) - assert len(s.hmm_reruns) == 3 - num_mutations = [] - for hmm_match in s.hmm_reruns.values(): - assert len(hmm_match.path) == 1 - assert len(hmm_match.mutations) == 20943