From b20da98f2b4851247655a860d4de7cfafd394877 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 13 Dec 2024 14:18:43 +0000 Subject: [PATCH 1/6] Fixup recombinant_example_2 --- tests/conftest.py | 74 +++++++++++++++-------------------------- tests/test_inference.py | 2 -- 2 files changed, 27 insertions(+), 49 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index faaa809..62157c8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -68,7 +68,6 @@ def fx_raw_viridian_metadata_tsv(): @pytest.fixture def fx_metadata_df(fx_metadata_tsv): - return read_metadata_df(fx_metadata_tsv) @@ -138,7 +137,7 @@ def fx_ts_map(tmp_path, fx_data_cache, fx_dataset, fx_match_db): # Load the ts from file to get the provenance data extra_kwargs = {} if date == dates[-1]: - # Force a bunch of retro groups in on the last day + # Force a bunch of retro groups in on the last day extra_kwargs = { "min_group_size": 1, "min_root_mutations": 0, @@ -150,7 +149,7 @@ def fx_ts_map(tmp_path, fx_data_cache, fx_dataset, fx_match_db): base_ts=cache_path, date=date, match_db=fx_match_db.path, - **extra_kwargs + **extra_kwargs, ) print( f"INFERRED {date} nodes={last_ts.num_nodes} mutations={last_ts.num_mutations}" @@ -166,25 +165,6 @@ def fx_ts_map(tmp_path, fx_data_cache, fx_dataset, fx_match_db): return d -def tmp_alignment_store(tmp_path, alignments): - path = tmp_path / "synthetic_alignments.db" - alignment_db = sc2ts.AlignmentStore(path, mode="rw") - alignment_db.append(alignments) - return alignment_db - - -def tmp_metadata_db(tmp_path, strains, date): - data = [] - for strain in strains: - data.append({"strain": strain, "date": date}) - df = pd.DataFrame(data) - csv_path = tmp_path / "metadata.csv" - df.to_csv(csv_path) - db_path = tmp_path / "metadata.db" - sc2ts.MetadataDb.import_csv(csv_path, db_path, sep=",") - return sc2ts.MetadataDb(db_path) - - def recombinant_alignments(dataset): """ Generate some recombinant alignments from existing haplotypes @@ -223,42 +203,40 @@ def recombinant_example_1(tmp_path, fx_ts_map, fx_dataset, ds_path): return ts -def recombinant_example_2(tmp_path, fx_ts_map, fx_alignment_store): +def recombinant_example_2(tmp_path, fx_ts_map, fx_dataset, ds_path): # Pick a distinct strain to be the root of our two new haplotypes added # on the first day. root_strain = "SRR11597116" - a = fx_alignment_store[root_strain] + a = fx_dataset.haplotypes[root_strain] base_ts = fx_ts_map["2020-02-13"] # This sequence has a bunch of Ns at the start, so we have to go inwards # from them to make sure we're not masking them out. - start = np.where(a != "N")[0][1] + 7 + start = np.where(a != -1)[0][1] + 7 left_a = a.copy() - left_a[start : start + 3] = "G" + left_a[start : start + 3] = 2 # "G" - end = np.where(a != "N")[0][-1] - 8 + end = np.where(a != -1)[0][-1] - 8 right_a = a.copy() - right_a[end - 3 : end] = "C" + right_a[end - 3 : end] = 1 # "C" a[start : start + 3] = left_a[start : start + 3] a[end - 3 : end] = right_a[end - 3 : end] - alignments = {"left": left_a, "right": right_a, "recombinant": a} - local_as = tmp_alignment_store(tmp_path, alignments) - date = "2020-03-01" - metadata_db = tmp_metadata_db(tmp_path, ["left", "right"], date) + alignments = {"left": left_a, "right": right_a} + ds = sc2ts.tmp_dataset(tmp_path / "tmp.zarr", alignments, date=date) + ts = sc2ts.extend( - alignment_store=local_as, - metadata_db=metadata_db, - base_ts=base_ts, + dataset=ds.path, + base_ts=base_ts.path, date=date, - match_db=sc2ts.MatchDb.initialise(tmp_path / "match.db"), + match_db=sc2ts.MatchDb.initialise(tmp_path / "match.db").path, ) samples_strain = ts.metadata["sc2ts"]["samples_strain"] assert samples_strain[-2:] == ["left", "right"] - assert ts.num_mutations == base_ts.num_mutations + 6 assert ts.num_nodes == base_ts.num_nodes + 2 assert ts.num_edges == base_ts.num_edges + 2 + assert ts.num_mutations == base_ts.num_mutations + 6 left_node = ts.samples()[-2] right_node = ts.samples()[-1] @@ -266,24 +244,25 @@ def recombinant_example_2(tmp_path, fx_ts_map, fx_alignment_store): for j, mut_id in enumerate(np.where(ts.mutations_node == left_node)[0]): mut = ts.mutation(mut_id) assert mut.derived_state == "G" - assert ts.sites_position[mut.site] == start + j + assert ts.sites_position[mut.site] == start + j + 1 for j, mut_id in enumerate(np.where(ts.mutations_node == right_node)[0]): mut = ts.mutation(mut_id) assert mut.derived_state == "C" - assert ts.sites_position[mut.site] == end - 3 + j + assert ts.sites_position[mut.site] == end - 3 + j + 1 + + ts_path = tmp_path / "intermediate.ts" + ts.dump(ts_path) # Now run again with the recombinant of these two date = "2020-03-02" - metadata_db = tmp_metadata_db(tmp_path, ["recombinant"], date) + ds = sc2ts.tmp_dataset(tmp_path / "tmp.zarr", {"recombinant": a}, date=date) rts = sc2ts.extend( - alignment_store=local_as, - metadata_db=metadata_db, - base_ts=ts, + dataset=ds.path, + base_ts=ts_path, date=date, - match_db=sc2ts.MatchDb.initialise(tmp_path / "match.db"), + match_db=sc2ts.MatchDb.initialise(tmp_path / "match.db").path, ) - return rts @@ -299,10 +278,11 @@ def fx_recombinant_example_1(tmp_path, fx_data_cache, fx_ts_map, fx_dataset): @pytest.fixture -def fx_recombinant_example_2(tmp_path, fx_data_cache, fx_ts_map, fx_alignment_store): +def fx_recombinant_example_2(tmp_path, fx_data_cache, fx_ts_map, fx_dataset): cache_path = fx_data_cache / "recombinant_ex2.ts" if not cache_path.exists(): print(f"Generating {cache_path}") - ts = recombinant_example_2(tmp_path, fx_ts_map, fx_alignment_store) + ds_cache_path = fx_data_cache / "recombinant_ex2_dataset.zarr" + ts = recombinant_example_2(tmp_path, fx_ts_map, fx_dataset, ds_cache_path) ts.dump(cache_path) return tskit.load(cache_path) diff --git a/tests/test_inference.py b/tests/test_inference.py index c24bb51..b3e240c 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -78,7 +78,6 @@ def test_get_recombinant_strains_ex1(self, fx_recombinant_example_1): d = sc2ts.get_recombinant_strains(fx_recombinant_example_1) assert d == {55: ["recombinant_example_1_0", "recombinant_example_1_1"]} - @pytest.mark.skip("Example broken by dataset") def test_get_recombinant_strains_ex2(self, fx_recombinant_example_2): d = sc2ts.get_recombinant_strains(fx_recombinant_example_2) assert d == {56: ["recombinant"]} @@ -1035,7 +1034,6 @@ def test_recombinant_example_1(self, fx_ts_map, fx_recombinant_example_1): assert row.parents == 2 assert row.causal_pango == {"Unknown": 2} - @pytest.mark.skip("Example broken by dataset") def test_recombinant_example_2(self, fx_ts_map, fx_recombinant_example_2): base_ts = fx_ts_map["2020-02-13"] date = "2020-03-01" From 80df1c58d7967d7bd5f6bd845570ca0d32131299 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 13 Dec 2024 14:24:42 +0000 Subject: [PATCH 2/6] Improve test settings in pyproject --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 1e19b81..42958b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,3 +58,7 @@ packages = ["sc2ts"] [tool.setuptools_scm] write_to = "sc2ts/_version.py" + +[tool.pytest.ini_options] +testpaths = "tests" +addopts = "--cov=sc2ts --cov-report term-missing" From 00b332dbf549db3ae178d9686e056123b02b26c9 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 13 Dec 2024 15:41:47 +0000 Subject: [PATCH 3/6] Add method to characterise recombinants --- sc2ts/inference.py | 35 +++++++++++++++++-- tests/conftest.py | 6 ++-- tests/test_inference.py | 75 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 111 insertions(+), 5 deletions(-) diff --git a/sc2ts/inference.py b/sc2ts/inference.py index 65a74f9..f3352be 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -364,8 +364,9 @@ class Sample: alignment_composition: Dict = None haplotype: List = None hmm_match: HmmMatch = None - hmm_reruns: Dict = dataclasses.field(default_factory=dict) + breakpoint_intervals: List = dataclasses.field(default_factory=list) flags: int = tskit.NODE_IS_SAMPLE + hmm_reruns: Dict = dataclasses.field(default_factory=dict) @property def is_recombinant(self): @@ -725,8 +726,9 @@ def _extend( num_threads=num_threads, memory_limit=memory_limit, ) - characterise_match_mutations(base_ts, samples) + characterise_recombinants(base_ts, samples) + for sample in unconditional_include_samples: # We want this sample to included unconditionally, so we set the # hmm cost to 0 < hmm_cost < hmm_cost_threshold. We use 0.5 @@ -840,6 +842,8 @@ def add_sample_to_tables(sample, tables, group_id=None): "alignment_composition": dict(sample.alignment_composition), "num_missing_sites": sample.num_missing_sites, } + if sample.is_recombinant: + sc2ts_md["breakpoint_intervals"] = sample.breakpoint_intervals if group_id is not None: sc2ts_md["group_id"] = group_id metadata = {**sample.metadata, "sc2ts": sc2ts_md} @@ -1681,6 +1685,33 @@ def get_closest_mutation(node, site_id): logger.debug(f"Characterised {num_mutations} mutations") +def characterise_recombinants(ts, samples): + """ + Update the metadata for any recombinants to add interval information to the metadata. + """ + recombinants = [s for s in samples if s.is_recombinant] + if len(recombinants) == 0: + return + logger.info(f"Characterising {len(recombinants)} recombinants") + + # NOTE: could make this more efficient by doing one call to genotype_matrix, + # but recombinants are rare so let's keep this simple + for s in recombinants: + parents = [seg.parent for seg in s.hmm_match.path] + # Can't have missing data here, so we're OK. + H = ts.genotype_matrix(samples=parents, isolated_as_missing=False).T + breakpoint_intervals = [] + for j in range(len(parents) - 1): + parents_differ = np.where(H[j] != H[j + 1])[0] + pos = ts.sites_position[parents_differ].astype(int) + right = s.hmm_match.path[j].right + right_index = np.searchsorted(pos, right) + assert pos[right_index] == right + left = pos[right_index - 1] + 1 + breakpoint_intervals.append((int(left), int(right))) + s.breakpoint_intervals = breakpoint_intervals + + def attach_tree( parent_ts, parent_tables, diff --git a/tests/conftest.py b/tests/conftest.py index 62157c8..dbe1c76 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -254,9 +254,11 @@ def recombinant_example_2(tmp_path, fx_ts_map, fx_dataset, ds_path): ts_path = tmp_path / "intermediate.ts" ts.dump(ts_path) - # Now run again with the recombinant of these two + # Now run again with the recombinant of these two, encoding the interval in the # name date = "2020-03-02" - ds = sc2ts.tmp_dataset(tmp_path / "tmp.zarr", {"recombinant": a}, date=date) + left = start + 3 + 1 + right = end - 3 + 1 + ds = sc2ts.tmp_dataset(tmp_path / "tmp.zarr", {f"recombinant_{left}:{right}": a}, date=date) rts = sc2ts.extend( dataset=ds.path, base_ts=ts_path, diff --git a/tests/test_inference.py b/tests/test_inference.py index b3e240c..ae077b9 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -80,7 +80,33 @@ def test_get_recombinant_strains_ex1(self, fx_recombinant_example_1): def test_get_recombinant_strains_ex2(self, fx_recombinant_example_2): d = sc2ts.get_recombinant_strains(fx_recombinant_example_2) - assert d == {56: ["recombinant"]} + assert d == {56: ["recombinant_114:29825"]} + + def test_recombinant_example_1(self, fx_recombinant_example_1): + ts = fx_recombinant_example_1 + samples_strain = ts.metadata["sc2ts"]["samples_strain"] + samples = ts.samples() + for s in ["recombinant_example_1_0", "recombinant_example_1_1"]: + u = samples[samples_strain.index(s)] + node = ts.node(u) + md = node.metadata["sc2ts"] + assert md["breakpoint_intervals"] == [[3788, 11083]] + assert md["hmm_match"]["path"] == [ + {"left": 0, "parent": 31, "right": 11083}, + {"left": 11083, "parent": 46, "right": 29904}, + ] + + def test_recombinant_example_2(self, fx_recombinant_example_2): + ts = fx_recombinant_example_2 + samples_strain = ts.metadata["sc2ts"]["samples_strain"] + u = ts.samples()[samples_strain.index("recombinant_114:29825")] + node = ts.node(u) + md = node.metadata["sc2ts"] + assert md["breakpoint_intervals"] == [[114, 29825]] + assert md["hmm_match"]["path"] == [ + {"left": 0, "parent": 53, "right": 29825}, + {"left": 29825, "parent": 54, "right": 29904}, + ] class TestSolveNumMismatches: @@ -1187,6 +1213,53 @@ def test_match_recombinant(self, fx_ts_map): assert m.path[1].right == ts.sequence_length +class TestCharacteriseRecombinants: + + def test_example_1(self, fx_ts_map): + ts, s = recombinant_example_1(fx_ts_map) + + interval_left = 3788 + interval_right = 11083 + left_parent = 31 + right_parent = 46 + + sc2ts.match_tsinfer( + samples=[s], + ts=ts, + num_mismatches=2, + mismatch_threshold=10, + ) + m = s.hmm_match + assert len(m.mutations) == 0 + assert len(m.path) == 2 + assert m.path[0].parent == left_parent + assert m.path[0].left == 0 + assert m.path[0].right == interval_right + assert m.path[1].parent == right_parent + assert m.path[1].left == interval_right + assert m.path[1].right == ts.sequence_length + + sc2ts.characterise_recombinants(ts, [s]) + assert s.breakpoint_intervals == [(interval_left, interval_right)] + + sc2ts.match_tsinfer( + samples=[s], + ts=ts, + num_mismatches=2, + mismatch_threshold=10, + mirror_coordinates=True, + ) + m = s.hmm_match + assert len(m.mutations) == 0 + assert len(m.path) == 2 + assert m.path[0].parent == left_parent + assert m.path[0].left == 0 + assert m.path[0].right == interval_left + assert m.path[1].parent == right_parent + assert m.path[1].left == interval_left + assert m.path[1].right == ts.sequence_length + + class TestMatchRecombinants: def test_example_1(self, fx_ts_map): ts, s = recombinant_example_1(fx_ts_map) From d44e570a8524487e6855000d100ee4aad9ad39de Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 13 Dec 2024 17:42:45 +0000 Subject: [PATCH 4/6] Trying to make example reusable --- sc2ts/inference.py | 2 +- tests/conftest.py | 87 +++++++++++++++++++++++++++++++++++++++++ tests/test_inference.py | 51 ++++++++++++++++++++++++ 3 files changed, 139 insertions(+), 1 deletion(-) diff --git a/sc2ts/inference.py b/sc2ts/inference.py index f3352be..f303261 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -1300,7 +1300,7 @@ def make_tsb(ts, num_alleles, mirror_coordinates=False): ts.edges_parent[index], ts.edges_child[index], ) - assert tsb.num_match_nodes == ts.num_nodes + # assert tsb.num_match_nodes == ts.num_nodes tsb.restore_mutations( ts.mutations_site, ts.mutations_node, derived_state, ts.mutations_parent diff --git a/tests/conftest.py b/tests/conftest.py index dbe1c76..9b23c82 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -267,6 +267,83 @@ def recombinant_example_2(tmp_path, fx_ts_map, fx_dataset, ds_path): ) return rts +def recombinant_example_3(tmp_path, fx_ts_map, fx_dataset, ds_path): + # Pick a distinct strain to be the root of our three new haplotypes added + # on the first day. + root_strain = "SRR11597116" + a = fx_dataset.haplotypes[root_strain] + base_ts = fx_ts_map["2020-02-13"] + # This sequence has a bunch of Ns at the start, so we have to go inwards + # from them to make sure we're not masking them out. + start = np.where(a != -1)[0][1] + 7 + left_a = a.copy() + left_a[start : start + 3] = 2 # "G" + + end = np.where(a != -1)[0][-1] - 8 + right_a = a.copy() + right_a[end - 3 : end] = 1 # "C" + + mid_a = a.copy() + mid_start = 15_000 + mid_end = 15_009 + mid_a[mid_start: mid_end] = 1 # "C" + + a = mid_a.copy() + a[start : start + 3] = left_a[start : start + 3] + a[end - 3 : end] = right_a[end - 3 : end] + + date = "2020-03-01" + alignments = {"left": left_a, "mid": mid_a, "right": right_a} + ds = sc2ts.tmp_dataset(tmp_path / "tmp.zarr", alignments, date=date) + + ts = sc2ts.extend( + dataset=ds.path, + base_ts=base_ts.path, + date=date, + hmm_cost_threshold=15, + match_db=sc2ts.MatchDb.initialise(tmp_path / "match.db").path, + ) + samples_strain = ts.metadata["sc2ts"]["samples_strain"] + assert samples_strain[-3:] == ["left", "mid", "right"] + assert ts.num_nodes == base_ts.num_nodes + 3 + assert ts.num_edges == base_ts.num_edges + 3 + assert ts.num_mutations == base_ts.num_mutations + 15 + + left_node, mid_node, right_node = ts.samples()[-3:] + + for j, mut_id in enumerate(np.where(ts.mutations_node == left_node)[0]): + mut = ts.mutation(mut_id) + assert mut.derived_state == "G" + assert ts.sites_position[mut.site] == start + j + 1 + + for j, mut_id in enumerate(np.where(ts.mutations_node == mid_node)[0]): + mut = ts.mutation(mut_id) + assert mut.derived_state == "C" + assert ts.sites_position[mut.site] == mid_start + j + 1 + + for j, mut_id in enumerate(np.where(ts.mutations_node == right_node)[0]): + mut = ts.mutation(mut_id) + assert mut.derived_state == "C" + assert ts.sites_position[mut.site] == end - 3 + j + 1 + + ts_path = tmp_path / "intermediate.ts" + ts.dump(ts_path) + + # Now run again with the recombinant of these three, encoding the intervals in the name + date = "2020-03-02" + left = start + 3 + 1 + right = end - 3 + 1 + name = f"recombinant_{left}:{mid_start + 1}:{mid_end + 1}:{right}" + ds = sc2ts.tmp_dataset(tmp_path / "tmp.zarr", {name: a}, date=date) + rts = sc2ts.extend( + dataset=ds.path, + base_ts=ts_path, + date=date, + hmm_cost_threshold=15, + match_db=sc2ts.MatchDb.initialise(tmp_path / "match.db").path, + ) + assert rts.num_samples == ts.num_samples + 1 + return rts @pytest.fixture def fx_recombinant_example_1(tmp_path, fx_data_cache, fx_ts_map, fx_dataset): @@ -288,3 +365,13 @@ def fx_recombinant_example_2(tmp_path, fx_data_cache, fx_ts_map, fx_dataset): ts = recombinant_example_2(tmp_path, fx_ts_map, fx_dataset, ds_cache_path) ts.dump(cache_path) return tskit.load(cache_path) + +@pytest.fixture +def fx_recombinant_example_3(tmp_path, fx_data_cache, fx_ts_map, fx_dataset): + cache_path = fx_data_cache / "recombinant_ex3.ts" + if not cache_path.exists(): + print(f"Generating {cache_path}") + ds_cache_path = fx_data_cache / "recombinant_ex3_dataset.zarr" + ts = recombinant_example_3(tmp_path, fx_ts_map, fx_dataset, ds_cache_path) + ts.dump(cache_path) + return tskit.load(cache_path) diff --git a/tests/test_inference.py b/tests/test_inference.py index ae077b9..4b468fd 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -49,6 +49,7 @@ def recombinant_example_1(ts_map): return ts, s + def tmp_metadata_db(tmp_path, strains, date): data = [] for strain in strains: @@ -108,6 +109,19 @@ def test_recombinant_example_2(self, fx_recombinant_example_2): {"left": 29825, "parent": 54, "right": 29904}, ] + def test_recombinant_example_3(self, fx_recombinant_example_3): + ts = fx_recombinant_example_3 + samples_strain = ts.metadata["sc2ts"]["samples_strain"] + u = ts.samples()[samples_strain.index("recombinant_114:15001:15010:29825")] + node = ts.node(u) + md = node.metadata["sc2ts"] + assert md["breakpoint_intervals"] == [[114, 15001], [15010, 29825]] + assert md["hmm_match"]["path"] == [ + {"left": 0, "parent": 53, "right": 15001}, + {"left": 15001, "parent": 54, "right": 29825}, + {"left": 29825, "parent": 55, "right": 29904}, + ] + class TestSolveNumMismatches: @pytest.mark.parametrize( @@ -1259,6 +1273,43 @@ def test_example_1(self, fx_ts_map): assert m.path[1].left == interval_left assert m.path[1].right == ts.sequence_length + def test_example_3(self, fx_recombinant_example_3): + ts = fx_recombinant_example_3 + strains = ts.metadata["sc2ts"]["samples_strain"] + assert strains[-1].startswith("recomb") + u = ts.samples()[-1] + h = ts.genotype_matrix(samples=[u]).T[0] + tables = ts.dump_tables() + keep_edges = ts.edges_child != u + tables.edges.keep_rows(keep_edges) + keep_nodes = np.ones(ts.num_nodes, dtype=bool) + tables.nodes[u] = tables.nodes[u].replace(flags=0) + tables.sort() + base_ts = tables.tree_sequence() + + alignment = np.full(int(ts.sequence_length), -1, dtype=np.int8) + alignment[ts.sites_position.astype(int)] = h + s = sc2ts.Sample("3way", "2020-02-14", haplotype=h.astype(np.int8)) + + sc2ts.match_tsinfer( + samples=[s], + ts=base_ts, + num_mismatches=2, + mismatch_threshold=10, + mirror_coordinates=False, + ) + m = s.hmm_match + print(m) + assert len(m.mutations) == 0 + assert len(m.path) == 2 + assert m.path[0].parent == left_parent + assert m.path[0].left == 0 + assert m.path[0].right == interval_left + assert m.path[1].parent == right_parent + assert m.path[1].left == interval_left + assert m.path[1].right == ts.sequence_length + + class TestMatchRecombinants: def test_example_1(self, fx_ts_map): From 28b1cb09ec03a337db687a061fdd2db5d1858983 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 13 Dec 2024 23:05:12 +0000 Subject: [PATCH 5/6] Fixup test example --- tests/test_inference.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/tests/test_inference.py b/tests/test_inference.py index 4b468fd..316ae88 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -49,7 +49,6 @@ def recombinant_example_1(ts_map): return ts, s - def tmp_metadata_db(tmp_path, strains, date): data = [] for strain in strains: @@ -1278,19 +1277,16 @@ def test_example_3(self, fx_recombinant_example_3): strains = ts.metadata["sc2ts"]["samples_strain"] assert strains[-1].startswith("recomb") u = ts.samples()[-1] - h = ts.genotype_matrix(samples=[u]).T[0] + h = ts.genotype_matrix(samples=[u], alleles=tuple(sc2ts.IUPAC_ALLELES)).T[0] tables = ts.dump_tables() - keep_edges = ts.edges_child != u + keep_edges = ts.edges_child < u tables.edges.keep_rows(keep_edges) keep_nodes = np.ones(ts.num_nodes, dtype=bool) tables.nodes[u] = tables.nodes[u].replace(flags=0) tables.sort() base_ts = tables.tree_sequence() - alignment = np.full(int(ts.sequence_length), -1, dtype=np.int8) - alignment[ts.sites_position.astype(int)] = h s = sc2ts.Sample("3way", "2020-02-14", haplotype=h.astype(np.int8)) - sc2ts.match_tsinfer( samples=[s], ts=base_ts, @@ -1298,17 +1294,22 @@ def test_example_3(self, fx_recombinant_example_3): mismatch_threshold=10, mirror_coordinates=False, ) + sc2ts.characterise_recombinants(ts, [s]) m = s.hmm_match - print(m) - assert len(m.mutations) == 0 - assert len(m.path) == 2 - assert m.path[0].parent == left_parent - assert m.path[0].left == 0 - assert m.path[0].right == interval_left - assert m.path[1].parent == right_parent - assert m.path[1].left == interval_left - assert m.path[1].right == ts.sequence_length - + assert m.parents == [53, 54, 55] + assert m.breakpoints == [0, 15001, 29825, 29904] + assert s.breakpoint_intervals == [(114, 15001), (15010, 29825)] + # Verify that these breakpoints correspond to the reverse-direction HMM + sc2ts.match_tsinfer( + samples=[s], + ts=base_ts, + num_mismatches=2, + mismatch_threshold=10, + mirror_coordinates=True, + ) + m = s.hmm_match + assert m.parents == [53, 54, 55] + assert m.breakpoints == [0, 114, 15010, 29904] class TestMatchRecombinants: From b357d4d750436c0216b100b1ce4603252a3a1500 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 13 Dec 2024 23:12:16 +0000 Subject: [PATCH 6/6] Remove hmm_reruns and other infrastructure for remathing recombinants Closes #159 --- sc2ts/cli.py | 98 ----------------------------------------- sc2ts/inference.py | 43 +----------------- tests/test_cli.py | 38 ---------------- tests/test_inference.py | 74 +------------------------------ tests/test_utils.py | 18 -------- 5 files changed, 2 insertions(+), 269 deletions(-) delete mode 100644 tests/test_utils.py diff --git a/sc2ts/cli.py b/sc2ts/cli.py index 7c37a68..79af3b8 100644 --- a/sc2ts/cli.py +++ b/sc2ts/cli.py @@ -633,103 +633,6 @@ def find_previous_date_path(date, path_pattern): return path -@click.command() -@click.argument("dataset", type=click.Path(exists=True, dir_okay=False)) -@click.argument("ts", type=click.Path(exists=True, dir_okay=False)) -@click.argument("path_pattern") -@num_mismatches -@click.option( - "--num-threads", - default=0, - type=int, - help="Number of match threads (default to one)", -) -@click.option("--progress/--no-progress", default=True) -@click.option("-v", "--verbose", count=True) -@click.option("-l", "--log-file", default=None, type=click.Path(dir_okay=False)) -def rematch_recombinants( - dataset, - ts, - path_pattern, - num_mismatches, - num_threads, - progress, - verbose, - log_file, -): - setup_logging(verbose, log_file) - ts = tszip.load(ts) - # This is a map of recombinant node to the samples involved in - # the original causal sample group. - recombinant_strains = sc2ts.get_recombinant_strains(ts) - logger.info( - f"Got {len(recombinant_strains)} recombinants and " - f"{sum(len(v) for v in recombinant_strains.values())} strains" - ) - - # Map recombinants to originating date - recombinant_to_path = {} - strain_to_recombinant = {} - all_strains = [] - for u, strains in recombinant_strains.items(): - date_added = ts.node(u).metadata["sc2ts"]["date_added"] - base_ts_path = find_previous_date_path(date_added, path_pattern) - recombinant_to_path[u] = base_ts_path - for strain in strains: - strain_to_recombinant[strain] = u - all_strains.append(strain) - - ds = sc2ts.Dataset(dataset) - progress_title = "Recomb" - samples = sc2ts.preprocess( - all_strains, - datset=ds, - show_progress=progress, - progress_title=progress_title, - keep_sites=ts.sites_position.astype(int), - num_workers=num_threads, - ) - - recombinant_to_samples = collections.defaultdict(list) - for sample in samples: - if sample.haplotype is None: - raise ValueError(f"No alignment stored for {sample.strain}") - recombinant = strain_to_recombinant[sample.strain] - recombinant_to_samples[recombinant].append(sample) - - work = [] - for recombinant, samples in recombinant_to_samples.items(): - for direction in ["forward", "reverse"]: - work.append( - MatchWork( - recombinant_to_path[recombinant], - samples, - num_mismatches=num_mismatches, - direction=direction, - ) - ) - - bar = sc2ts.get_progress(None, progress_title, "HMM", progress, total=len(work)) - - def output(hmm_runs): - bar.update() - for run in hmm_runs: - print(run.asjson()) - - results = [] - if num_threads == 0: - for w in work: - hmm_runs = _match_worker(w) - output(hmm_runs) - else: - with cf.ProcessPoolExecutor(num_threads) as executor: - futures = [executor.submit(_match_worker, w) for w in work] - for future in cf.as_completed(futures): - hmm_runs = future.result() - output(hmm_runs) - bar.close() - - @click.version_option(core.__version__) @click.group() def cli(): @@ -747,5 +650,4 @@ def cli(): cli.add_command(infer) cli.add_command(validate) cli.add_command(_match) -cli.add_command(rematch_recombinants) cli.add_command(tally_lineages) diff --git a/sc2ts/inference.py b/sc2ts/inference.py index f303261..488b127 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -366,7 +366,6 @@ class Sample: hmm_match: HmmMatch = None breakpoint_intervals: List = dataclasses.field(default_factory=list) flags: int = tskit.NODE_IS_SAMPLE - hmm_reruns: Dict = dataclasses.field(default_factory=dict) @property def is_recombinant(self): @@ -382,46 +381,7 @@ def num_deletion_sites(self): def summary(self): hmm_match = "No match" if self.hmm_match is None else self.hmm_match.summary() - s = f"{self.strain} {self.date} {self.pango} {hmm_match}" - for name, hmm_match in self.hmm_reruns.items(): - s += f"; {name}: {hmm_match.summary()}" - return s - - -# TODO not clear if we still need this as mirroring is done differently now. -# Remove if we don't have any issues with running the HMM in reverse -def pad_sites(ts): - """ - Fill in missing sites with the reference state. - """ - ref = core.get_reference_sequence() - missing_sites = set(np.arange(1, len(ref))) - missing_sites -= set(ts.sites_position.astype(int)) - tables = ts.dump_tables() - for pos in missing_sites: - tables.sites.add_row(pos, ref[pos]) - tables.sort() - return tables.tree_sequence() - - -# TODO remove this -def match_recombinants( - samples, base_ts, num_mismatches, show_progress=False, num_threads=None -): - for hmm_pass in ["forward", "reverse", "no_recombination"]: - logger.info(f"Running {hmm_pass} pass for {len(samples)} recombinants") - match_tsinfer( - samples=samples, - ts=base_ts, - num_mismatches=1000 if hmm_pass == "no_recombination" else num_mismatches, - mismatch_threshold=100, - num_threads=num_threads, - show_progress=show_progress, - mirror_coordinates=hmm_pass == "reverse", - ) - - for sample in samples: - sample.hmm_reruns[hmm_pass] = sample.hmm_match + return f"{self.strain} {self.date} {self.pango} {hmm_match}" def match_samples( @@ -838,7 +798,6 @@ def update_top_level_metadata(ts, date, retro_groups, samples): def add_sample_to_tables(sample, tables, group_id=None): sc2ts_md = { "hmm_match": sample.hmm_match.asdict(), - "hmm_reruns": {k: m.asdict() for k, m in sample.hmm_reruns.items()}, "alignment_composition": dict(sample.alignment_composition), "num_missing_sites": sample.num_missing_sites, } diff --git a/tests/test_cli.py b/tests/test_cli.py index 563773b..7f0e603 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -369,44 +369,6 @@ def test_multiple_override(self, tmp_path, fx_ts_map, fx_dataset): assert ts.num_samples == 0 -@pytest.mark.skip("Broken by dataset") -class TestRunRematchRecombinants: - - @pytest.mark.parametrize("num_threads", [0, 1, 2]) - def test_defaults( - self, tmp_path, fx_recombinant_example_1, fx_data_cache, num_threads - ): - ts_path = fx_data_cache / "recombinant_ex1.ts" - as_path = fx_data_cache / "recombinant_ex1_alignments.db" - pattern = str(fx_data_cache) + "/{}.ts" - runner = ct.CliRunner(mix_stderr=False) - cmd = ( - f"rematch-recombinants {as_path} {ts_path} {pattern} " - f"--num-threads={num_threads}" - ) - result = runner.invoke( - cli.cli, - cmd, - catch_exceptions=False, - ) - assert result.exit_code == 0 - lines = result.stdout.splitlines() - assert len(lines) == 4 - results = collections.defaultdict(list) - for line in lines: - d = json.loads(line) - results[d["strain"]].append(result) - - assert len(results) == 2 - assert set(results.keys()) == { - "recombinant_example_1_0", - "recombinant_example_1_1", - } - - assert len(results["recombinant_example_1_0"]) == 2 - assert len(results["recombinant_example_1_1"]) == 2 - - class TestValidate: @pytest.mark.parametrize("date", ["2020-01-01", "2020-02-11"]) diff --git a/tests/test_inference.py b/tests/test_inference.py index 316ae88..f7d8aab 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1022,7 +1022,6 @@ def test_recombinant_example_1(self, fx_ts_map, fx_recombinant_example_1): {"left": bp, "parent": right_parent, "right": 29904}, ], } - assert smd["hmm_reruns"] == {} sample = ts.node(ts.samples()[-1]) smd = sample.metadata["sc2ts"] @@ -1036,7 +1035,6 @@ def test_recombinant_example_1(self, fx_ts_map, fx_recombinant_example_1): {"left": bp, "parent": right_parent, "right": 29904}, ], } - assert smd["hmm_reruns"] == {} recomb_node = ts.node(ts.num_nodes - 1) assert recomb_node.flags == sc2ts.NODE_IS_RECOMBINANT @@ -1078,7 +1076,7 @@ def test_recombinant_example_2(self, fx_ts_map, fx_recombinant_example_2): date = "2020-03-01" rts = fx_recombinant_example_2 samples_strain = rts.metadata["sc2ts"]["samples_strain"] - assert samples_strain[-3:] == ["left", "right", "recombinant"] + assert samples_strain[-3:] == ["left", "right", "recombinant_114:29825"] sample = rts.node(rts.samples()[-1]) smd = sample.metadata["sc2ts"] @@ -1090,8 +1088,6 @@ def test_recombinant_example_2(self, fx_ts_map, fx_recombinant_example_2): ], } - assert smd["hmm_reruns"] == {} - def test_all_As(self, tmp_path, fx_ts_map, fx_dataset): # Same as the recombinant_example_1() function above # Just to get something that looks like an alignment easily @@ -1310,71 +1306,3 @@ def test_example_3(self, fx_recombinant_example_3): m = s.hmm_match assert m.parents == [53, 54, 55] assert m.breakpoints == [0, 114, 15010, 29904] - - -class TestMatchRecombinants: - def test_example_1(self, fx_ts_map): - ts, s = recombinant_example_1(fx_ts_map) - - sc2ts.match_recombinants( - samples=[s], - base_ts=ts, - num_mismatches=2, - num_threads=0, - ) - left_parent = 31 - right_parent = 46 - interval_right = 11083 - - m = s.hmm_reruns["forward"] - assert len(m.mutations) == 0 - assert len(m.path) == 2 - assert m.path[0].parent == left_parent - assert m.path[0].left == 0 - assert m.path[0].right == interval_right - assert m.path[1].parent == right_parent - assert m.path[1].left == interval_right - assert m.path[1].right == ts.sequence_length - - interval_left = 3788 - m = s.hmm_reruns["reverse"] - assert len(m.mutations) == 0 - assert len(m.path) == 2 - assert m.path[0].left == 0 - assert m.path[0].right == interval_left - assert m.path[0].parent == left_parent - assert m.path[1].parent == right_parent - assert m.path[1].left == interval_left - assert m.path[1].right == ts.sequence_length - - m = s.hmm_reruns["no_recombination"] - # It seems that we can choose either the left or right parent - # arbitrarily :shrug: - assert len(m.mutations) == 3 - assert m.mutation_summary() in [ - "[A871G, A3027G, C3787T]", - "[T11083G, C15324T, C29303T]", - ] - assert len(m.path) == 1 - assert m.path[0].parent in [left_parent, right_parent] - assert m.path[0].left == 0 - assert m.path[0].right == ts.sequence_length - - assert "no_recombination" in s.summary() - - def test_all_As(self, fx_ts_map): - ts = fx_ts_map["2020-02-13"] - h = np.zeros(ts.num_sites, dtype=np.int8) - s = sc2ts.Sample("zerotype", "2020-02-14", haplotype=h) - - sc2ts.match_recombinants( - samples=[s], - base_ts=ts, - num_mismatches=3, - num_threads=0, - ) - assert len(s.hmm_reruns) == 3 - num_mutations = [] - for hmm_match in s.hmm_reruns.values(): - assert len(hmm_match.path) == 1 - assert len(hmm_match.mutations) == 20943 diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index 3b0eb91..0000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,18 +0,0 @@ -import numpy as np -import pytest - -import sc2ts -import sc2ts.utils as utils -import util - - -class TestPadSites: - def check_site_padding(self, ts): - ts = sc2ts.pad_sites(ts) - ref = sc2ts.core.get_reference_sequence(as_array=True) - assert ts.num_sites == len(ref) - 1 - ancestral_state = ts.tables.sites.ancestral_state.view("S1").astype(str) - assert np.all(ancestral_state == ref[1:]) - - def test_initial(self): - self.check_site_padding(sc2ts.initial_ts())