def characterise_recombinants(ts, samples):
    """
    Compute the breakpoint interval(s) for every recombinant in ``samples``.

    For each pair of consecutive segments in a recombinant's HMM copying
    path, an interval ``(left, right)`` is recorded, where ``right`` is the
    right coordinate of the earlier segment (the breakpoint chosen by the
    HMM) and ``left`` is one past the closest position to the left of
    ``right`` at which the two parents' genotypes differ. If the parents do
    not differ anywhere to the left of ``right``, ``left`` is 0.

    The result is stored as a list of ``(int, int)`` tuples on each
    recombinant's ``breakpoint_intervals`` attribute; samples that are not
    recombinants are left untouched.

    :param ts: Tree sequence providing ``genotype_matrix`` and
        ``sites_position`` for the parent nodes.
    :param samples: Iterable of sample objects exposing ``is_recombinant``
        and ``hmm_match.path``.
    :raises AssertionError: If a breakpoint does not fall on a site at which
        the two flanking parents differ.
    """
    recombinants = [s for s in samples if s.is_recombinant]
    if not recombinants:
        return
    logger.info(f"Characterising {len(recombinants)} recombinants")

    sites_position = ts.sites_position
    # NOTE: could make this more efficient by doing one call to genotype_matrix,
    # but recombinants are rare so let's keep this simple
    for s in recombinants:
        path = s.hmm_match.path
        parents = [seg.parent for seg in path]
        # The same parent can occur more than once in a path (e.g. A, B, A).
        # Decode each distinct parent once and map path segments onto rows,
        # so that we never hand duplicate node IDs to genotype_matrix.
        unique_parents, parent_row = np.unique(parents, return_inverse=True)
        # Can't have missing data here, so we're OK.
        H = ts.genotype_matrix(
            samples=unique_parents, isolated_as_missing=False
        ).T
        breakpoint_intervals = []
        for j in range(len(parents) - 1):
            parents_differ = np.where(
                H[parent_row[j]] != H[parent_row[j + 1]]
            )[0]
            pos = sites_position[parents_differ].astype(int)
            right = path[j].right
            right_index = np.searchsorted(pos, right)
            # The HMM only switches parent at a site where they differ. Check
            # the index bound first so that a violation is reported as the
            # invariant failure it is, rather than as an IndexError.
            if right_index >= len(pos) or pos[right_index] != right:
                raise AssertionError(
                    f"Breakpoint {right} is not at a site where parents "
                    f"{parents[j]} and {parents[j + 1]} differ"
                )
            if right_index > 0:
                left = pos[right_index - 1] + 1
            else:
                # No differing site to the left of the breakpoint; indexing
                # with -1 here would wrap around to the *last* differing
                # site and give left > right.
                left = 0
            breakpoint_intervals.append((int(left), int(right)))
        s.breakpoint_intervals = breakpoint_intervals
class TestCharacteriseRecombinants:

    def test_example_1(self, fx_ts_map):
        """
        Check that the interval reported by characterise_recombinants is
        bracketed by the breakpoints chosen by the forward and mirrored HMMs.
        """
        ts, s = recombinant_example_1(fx_ts_map)

        interval_left = 3788
        interval_right = 11083
        left_parent = 31
        right_parent = 46

        def check_match(breakpoint):
            # The match must be mutation-free and consist of exactly two
            # segments that switch parent at ``breakpoint``.
            m = s.hmm_match
            assert len(m.mutations) == 0
            assert len(m.path) == 2
            first, second = m.path[0], m.path[1]
            assert first.parent == left_parent
            assert first.left == 0
            assert first.right == breakpoint
            assert second.parent == right_parent
            assert second.left == breakpoint
            assert second.right == ts.sequence_length

        # Forward matching puts the breakpoint at the right end of the interval.
        sc2ts.match_tsinfer(
            samples=[s],
            ts=ts,
            num_mismatches=2,
            mismatch_threshold=10,
        )
        check_match(interval_right)

        sc2ts.characterise_recombinants(ts, [s])
        assert s.breakpoint_intervals == [(interval_left, interval_right)]

        # Matching in mirrored coordinates puts it at the left end instead.
        sc2ts.match_tsinfer(
            samples=[s],
            ts=ts,
            num_mismatches=2,
            mismatch_threshold=10,
            mirror_coordinates=True,
        )
        check_match(interval_left)