diff --git a/sc2ts/dataset.py b/sc2ts/dataset.py index 49c2749..722baa3 100644 --- a/sc2ts/dataset.py +++ b/sc2ts/dataset.py @@ -9,6 +9,7 @@ import logging import pathlib +import tskit import tqdm import pandas as pd import zarr @@ -309,6 +310,44 @@ def reorder(self, path, additional_fields=list(), show_progress=False): index = np.lexsort([self.metadata.fields[f] for f in sort_key[::-1]]) self.copy(path, sample_id=sample_id[index], show_progress=show_progress) + def map_deletions_to_ts(self, ts, start, end): + """ + Map a region of the dataset onto the specified tree sequence, + adding deletions to the tree sequence where necessary. + The tree sequence samples must all be present in the dataset. + A new tree sequence is returned whose original mutations in the + region have been removed and replaced using parsimony + """ + sample_id = ts.metadata["sc2ts"]["samples_strain"] + if sample_id[0].startswith("Wuhan"): + # Note: a lot of fiddling around here is due to the potential for sample 0 + # to be the reference, which is not in the viridian dataset. + sample_id = sample_id[1:] + tables = ts.dump_tables() + n_tm = ts.nodes_time + tree = ts.first() + del_sites = [ts.site(position=p).id for p in range(start, end)] + tables.mutations.keep_rows(np.logical_not(np.isin(ts.mutations_site, del_sites))) + for var in self.variants(sample_id, np.arange(start, end)): + tree.seek(var.position) + site = ts.site(position=var.position) + # treat non-nucleotide alleles (other IUPAC codes) as missing , but keep "-" + keep = np.isin(var.alleles, np.array(['A', 'C', 'G', 'T', '-'])) + g = var.genotypes.copy() + g[keep[g] == False] = tskit.MISSING_DATA + anc = site.ancestral_state + # Pad the start with the anc state (non-viridian) Wuhan sample, so add that + if len(g) < ts.num_samples: + g = np.append([list(var.alleles).index(anc)], g) + _, mutations = tree.map_mutations(g, list(var.alleles), ancestral_state=anc) + m_id_map = {tskit.NULL: tskit.NULL} + for list_id, m in enumerate(mutations): + m_id_map[list_id] = tables.mutations.append( + m.replace(site=site.id, parent=m_id_map[m.parent], time=n_tm[m.node]) + ) + tables.sort() + return tables.tree_sequence() + @staticmethod def new(path, samples_chunk_size=None, variants_chunk_size=None): diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 2ee1e93..e565acf 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -269,6 +269,9 @@ def test_examples(self, fx_dataset): ], ) + @pytest.mark.skip("Not implemented") + def test_map_deletions_to_ts(self, fx_dataset): + pass class TestDatasetAlignments: