diff --git a/pyproject.toml b/pyproject.toml
index 1e19b81..42958b2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -58,3 +58,7 @@ packages = ["sc2ts"]
 
 [tool.setuptools_scm]
 write_to = "sc2ts/_version.py"
+
+[tool.pytest.ini_options]
+testpaths = "tests"
+addopts = "--cov=sc2ts --cov-report term-missing"
diff --git a/sc2ts/cli.py b/sc2ts/cli.py
index 7c37a68..79af3b8 100644
--- a/sc2ts/cli.py
+++ b/sc2ts/cli.py
@@ -633,103 +633,6 @@ def find_previous_date_path(date, path_pattern):
     return path
 
 
-@click.command()
-@click.argument("dataset", type=click.Path(exists=True, dir_okay=False))
-@click.argument("ts", type=click.Path(exists=True, dir_okay=False))
-@click.argument("path_pattern")
-@num_mismatches
-@click.option(
-    "--num-threads",
-    default=0,
-    type=int,
-    help="Number of match threads (default to one)",
-)
-@click.option("--progress/--no-progress", default=True)
-@click.option("-v", "--verbose", count=True)
-@click.option("-l", "--log-file", default=None, type=click.Path(dir_okay=False))
-def rematch_recombinants(
-    dataset,
-    ts,
-    path_pattern,
-    num_mismatches,
-    num_threads,
-    progress,
-    verbose,
-    log_file,
-):
-    setup_logging(verbose, log_file)
-    ts = tszip.load(ts)
-    # This is a map of recombinant node to the samples involved in
-    # the original causal sample group.
-    recombinant_strains = sc2ts.get_recombinant_strains(ts)
-    logger.info(
-        f"Got {len(recombinant_strains)} recombinants and "
-        f"{sum(len(v) for v in recombinant_strains.values())} strains"
-    )
-
-    # Map recombinants to originating date
-    recombinant_to_path = {}
-    strain_to_recombinant = {}
-    all_strains = []
-    for u, strains in recombinant_strains.items():
-        date_added = ts.node(u).metadata["sc2ts"]["date_added"]
-        base_ts_path = find_previous_date_path(date_added, path_pattern)
-        recombinant_to_path[u] = base_ts_path
-        for strain in strains:
-            strain_to_recombinant[strain] = u
-            all_strains.append(strain)
-
-    ds = sc2ts.Dataset(dataset)
-    progress_title = "Recomb"
-    samples = sc2ts.preprocess(
-        all_strains,
-        datset=ds,
-        show_progress=progress,
-        progress_title=progress_title,
-        keep_sites=ts.sites_position.astype(int),
-        num_workers=num_threads,
-    )
-
-    recombinant_to_samples = collections.defaultdict(list)
-    for sample in samples:
-        if sample.haplotype is None:
-            raise ValueError(f"No alignment stored for {sample.strain}")
-        recombinant = strain_to_recombinant[sample.strain]
-        recombinant_to_samples[recombinant].append(sample)
-
-    work = []
-    for recombinant, samples in recombinant_to_samples.items():
-        for direction in ["forward", "reverse"]:
-            work.append(
-                MatchWork(
-                    recombinant_to_path[recombinant],
-                    samples,
-                    num_mismatches=num_mismatches,
-                    direction=direction,
-                )
-            )
-
-    bar = sc2ts.get_progress(None, progress_title, "HMM", progress, total=len(work))
-
-    def output(hmm_runs):
-        bar.update()
-        for run in hmm_runs:
-            print(run.asjson())
-
-    results = []
-    if num_threads == 0:
-        for w in work:
-            hmm_runs = _match_worker(w)
-            output(hmm_runs)
-    else:
-        with cf.ProcessPoolExecutor(num_threads) as executor:
-            futures = [executor.submit(_match_worker, w) for w in work]
-            for future in cf.as_completed(futures):
-                hmm_runs = future.result()
-                output(hmm_runs)
-    bar.close()
-
-
 @click.version_option(core.__version__)
 @click.group()
 def cli():
@@ -747,5 +650,4 @@ def cli():
 cli.add_command(infer)
 cli.add_command(validate)
 cli.add_command(_match)
-cli.add_command(rematch_recombinants)
 cli.add_command(tally_lineages)
diff --git a/sc2ts/inference.py b/sc2ts/inference.py
index 65a74f9..488b127 100644
--- a/sc2ts/inference.py
+++ b/sc2ts/inference.py
@@ -364,7 +364,7 @@ class Sample:
     alignment_composition: Dict = None
     haplotype: List = None
     hmm_match: HmmMatch = None
-    hmm_reruns: Dict = dataclasses.field(default_factory=dict)
+    breakpoint_intervals: List = dataclasses.field(default_factory=list)
     flags: int = tskit.NODE_IS_SAMPLE
 
     @property
@@ -381,46 +381,7 @@ def num_deletion_sites(self):
 
     def summary(self):
         hmm_match = "No match" if self.hmm_match is None else self.hmm_match.summary()
-        s = f"{self.strain} {self.date} {self.pango} {hmm_match}"
-        for name, hmm_match in self.hmm_reruns.items():
-            s += f"; {name}: {hmm_match.summary()}"
-        return s
-
-
-# TODO not clear if we still need this as mirroring is done differently now.
-# Remove if we don't have any issues with running the HMM in reverse
-def pad_sites(ts):
-    """
-    Fill in missing sites with the reference state.
-    """
-    ref = core.get_reference_sequence()
-    missing_sites = set(np.arange(1, len(ref)))
-    missing_sites -= set(ts.sites_position.astype(int))
-    tables = ts.dump_tables()
-    for pos in missing_sites:
-        tables.sites.add_row(pos, ref[pos])
-    tables.sort()
-    return tables.tree_sequence()
-
-
-# TODO remove this
-def match_recombinants(
-    samples, base_ts, num_mismatches, show_progress=False, num_threads=None
-):
-    for hmm_pass in ["forward", "reverse", "no_recombination"]:
-        logger.info(f"Running {hmm_pass} pass for {len(samples)} recombinants")
-        match_tsinfer(
-            samples=samples,
-            ts=base_ts,
-            num_mismatches=1000 if hmm_pass == "no_recombination" else num_mismatches,
-            mismatch_threshold=100,
-            num_threads=num_threads,
-            show_progress=show_progress,
-            mirror_coordinates=hmm_pass == "reverse",
-        )
-
-        for sample in samples:
-            sample.hmm_reruns[hmm_pass] = sample.hmm_match
+        return f"{self.strain} {self.date} {self.pango} {hmm_match}"
 
 
 def match_samples(
@@ -725,8 +686,9 @@ def _extend(
         num_threads=num_threads,
         memory_limit=memory_limit,
     )
-
     characterise_match_mutations(base_ts, samples)
+    characterise_recombinants(base_ts, samples)
+
     for sample in unconditional_include_samples:
         # We want this sample to included unconditionally, so we set the
         # hmm cost to 0 < hmm_cost < hmm_cost_threshold. We use 0.5
@@ -836,10 +798,11 @@ def update_top_level_metadata(ts, date, retro_groups, samples):
 def add_sample_to_tables(sample, tables, group_id=None):
     sc2ts_md = {
         "hmm_match": sample.hmm_match.asdict(),
-        "hmm_reruns": {k: m.asdict() for k, m in sample.hmm_reruns.items()},
         "alignment_composition": dict(sample.alignment_composition),
        "num_missing_sites": sample.num_missing_sites,
     }
+    if sample.is_recombinant:
+        sc2ts_md["breakpoint_intervals"] = sample.breakpoint_intervals
     if group_id is not None:
         sc2ts_md["group_id"] = group_id
     metadata = {**sample.metadata, "sc2ts": sc2ts_md}
@@ -1296,7 +1259,7 @@ def make_tsb(ts, num_alleles, mirror_coordinates=False):
         ts.edges_parent[index],
         ts.edges_child[index],
     )
-    assert tsb.num_match_nodes == ts.num_nodes
+    # assert tsb.num_match_nodes == ts.num_nodes
 
     tsb.restore_mutations(
         ts.mutations_site, ts.mutations_node, derived_state, ts.mutations_parent
@@ -1681,6 +1644,33 @@ def get_closest_mutation(node, site_id):
     logger.debug(f"Characterised {num_mutations} mutations")
 
 
+def characterise_recombinants(ts, samples):
+    """
+    Update the metadata for any recombinants to add interval information to the metadata.
+    """
+    recombinants = [s for s in samples if s.is_recombinant]
+    if len(recombinants) == 0:
+        return
+    logger.info(f"Characterising {len(recombinants)} recombinants")
+
+    # NOTE: could make this more efficient by doing one call to genotype_matrix,
+    # but recombinants are rare so let's keep this simple
+    for s in recombinants:
+        parents = [seg.parent for seg in s.hmm_match.path]
+        # Can't have missing data here, so we're OK.
+        H = ts.genotype_matrix(samples=parents, isolated_as_missing=False).T
+        breakpoint_intervals = []
+        for j in range(len(parents) - 1):
+            parents_differ = np.where(H[j] != H[j + 1])[0]
+            pos = ts.sites_position[parents_differ].astype(int)
+            right = s.hmm_match.path[j].right
+            right_index = np.searchsorted(pos, right)
+            assert pos[right_index] == right
+            left = pos[right_index - 1] + 1
+            breakpoint_intervals.append((int(left), int(right)))
+        s.breakpoint_intervals = breakpoint_intervals
+
+
 def attach_tree(
     parent_ts,
     parent_tables,
diff --git a/tests/conftest.py b/tests/conftest.py
index faaa809..9b23c82 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -68,7 +68,6 @@ def fx_raw_viridian_metadata_tsv():
 
 
 @pytest.fixture
 def fx_metadata_df(fx_metadata_tsv):
-
     return read_metadata_df(fx_metadata_tsv)
 
@@ -138,7 +137,7 @@ def fx_ts_map(tmp_path, fx_data_cache, fx_dataset, fx_match_db):
         # Load the ts from file to get the provenance data
         extra_kwargs = {}
         if date == dates[-1]:
-            # Force a bunch of retro groups in on the last day
+            # Force a bunch of retro groups in on the last day
             extra_kwargs = {
                 "min_group_size": 1,
                 "min_root_mutations": 0,
@@ -150,7 +149,7 @@ def fx_ts_map(tmp_path, fx_data_cache, fx_dataset, fx_match_db):
            base_ts=cache_path,
            date=date,
            match_db=fx_match_db.path,
-            **extra_kwargs
+            **extra_kwargs,
        )
        print(
            f"INFERRED {date} nodes={last_ts.num_nodes} mutations={last_ts.num_mutations}"
@@ -166,25 +165,6 @@ def fx_ts_map(tmp_path, fx_data_cache, fx_dataset, fx_match_db):
     return d
 
 
-def tmp_alignment_store(tmp_path, alignments):
-    path = tmp_path / "synthetic_alignments.db"
-    alignment_db = sc2ts.AlignmentStore(path, mode="rw")
-    alignment_db.append(alignments)
-    return alignment_db
-
-
-def tmp_metadata_db(tmp_path, strains, date):
-    data = []
-    for strain in strains:
-        data.append({"strain": strain, "date": date})
-    df = pd.DataFrame(data)
-    csv_path = tmp_path / "metadata.csv"
-    df.to_csv(csv_path)
-    db_path = tmp_path / "metadata.db"
-    sc2ts.MetadataDb.import_csv(csv_path, db_path, sep=",")
-    return sc2ts.MetadataDb(db_path)
-
-
 def recombinant_alignments(dataset):
     """
     Generate some recombinant alignments from existing haplotypes
@@ -223,42 +203,40 @@ def recombinant_example_1(tmp_path, fx_ts_map, fx_dataset, ds_path):
     return ts
 
 
-def recombinant_example_2(tmp_path, fx_ts_map, fx_alignment_store):
+def recombinant_example_2(tmp_path, fx_ts_map, fx_dataset, ds_path):
     # Pick a distinct strain to be the root of our two new haplotypes added
     # on the first day.
     root_strain = "SRR11597116"
-    a = fx_alignment_store[root_strain]
+    a = fx_dataset.haplotypes[root_strain]
 
     base_ts = fx_ts_map["2020-02-13"]
     # This sequence has a bunch of Ns at the start, so we have to go inwards
     # from them to make sure we're not masking them out.
-    start = np.where(a != "N")[0][1] + 7
+    start = np.where(a != -1)[0][1] + 7
     left_a = a.copy()
-    left_a[start : start + 3] = "G"
+    left_a[start : start + 3] = 2  # "G"
 
-    end = np.where(a != "N")[0][-1] - 8
+    end = np.where(a != -1)[0][-1] - 8
     right_a = a.copy()
-    right_a[end - 3 : end] = "C"
+    right_a[end - 3 : end] = 1  # "C"
 
     a[start : start + 3] = left_a[start : start + 3]
     a[end - 3 : end] = right_a[end - 3 : end]
-    alignments = {"left": left_a, "right": right_a, "recombinant": a}
-    local_as = tmp_alignment_store(tmp_path, alignments)
-
     date = "2020-03-01"
-    metadata_db = tmp_metadata_db(tmp_path, ["left", "right"], date)
+    alignments = {"left": left_a, "right": right_a}
+    ds = sc2ts.tmp_dataset(tmp_path / "tmp.zarr", alignments, date=date)
+
     ts = sc2ts.extend(
-        alignment_store=local_as,
-        metadata_db=metadata_db,
-        base_ts=base_ts,
+        dataset=ds.path,
+        base_ts=base_ts.path,
         date=date,
-        match_db=sc2ts.MatchDb.initialise(tmp_path / "match.db"),
+        match_db=sc2ts.MatchDb.initialise(tmp_path / "match.db").path,
     )
     samples_strain = ts.metadata["sc2ts"]["samples_strain"]
     assert samples_strain[-2:] == ["left", "right"]
-    assert ts.num_mutations == base_ts.num_mutations + 6
     assert ts.num_nodes == base_ts.num_nodes + 2
     assert ts.num_edges == base_ts.num_edges + 2
+    assert ts.num_mutations == base_ts.num_mutations + 6
 
     left_node = ts.samples()[-2]
     right_node = ts.samples()[-1]
@@ -266,26 +244,106 @@
     for j, mut_id in enumerate(np.where(ts.mutations_node == left_node)[0]):
         mut = ts.mutation(mut_id)
         assert mut.derived_state == "G"
-        assert ts.sites_position[mut.site] == start + j
+        assert ts.sites_position[mut.site] == start + j + 1
 
     for j, mut_id in enumerate(np.where(ts.mutations_node == right_node)[0]):
         mut = ts.mutation(mut_id)
         assert mut.derived_state == "C"
-        assert ts.sites_position[mut.site] == end - 3 + j
+        assert ts.sites_position[mut.site] == end - 3 + j + 1
 
-    # Now run again with the recombinant of these two
+    ts_path = tmp_path / "intermediate.ts"
+    ts.dump(ts_path)
+
+    # Now run again with the recombinant of these two, encoding the interval in the
+    # name
     date = "2020-03-02"
-    metadata_db = tmp_metadata_db(tmp_path, ["recombinant"], date)
+    left = start + 3 + 1
+    right = end - 3 + 1
+    ds = sc2ts.tmp_dataset(tmp_path / "tmp.zarr", {f"recombinant_{left}:{right}": a}, date=date)
     rts = sc2ts.extend(
-        alignment_store=local_as,
-        metadata_db=metadata_db,
-        base_ts=ts,
+        dataset=ds.path,
+        base_ts=ts_path,
         date=date,
-        match_db=sc2ts.MatchDb.initialise(tmp_path / "match.db"),
+        match_db=sc2ts.MatchDb.initialise(tmp_path / "match.db").path,
     )
-
     return rts
+
+
+def recombinant_example_3(tmp_path, fx_ts_map, fx_dataset, ds_path):
+    # Pick a distinct strain to be the root of our three new haplotypes added
+    # on the first day.
+    root_strain = "SRR11597116"
+    a = fx_dataset.haplotypes[root_strain]
+    base_ts = fx_ts_map["2020-02-13"]
+    # This sequence has a bunch of Ns at the start, so we have to go inwards
+    # from them to make sure we're not masking them out.
+    start = np.where(a != -1)[0][1] + 7
+    left_a = a.copy()
+    left_a[start : start + 3] = 2  # "G"
+
+    end = np.where(a != -1)[0][-1] - 8
+    right_a = a.copy()
+    right_a[end - 3 : end] = 1  # "C"
+
+    mid_a = a.copy()
+    mid_start = 15_000
+    mid_end = 15_009
+    mid_a[mid_start:mid_end] = 1  # "C"
+
+    a = mid_a.copy()
+    a[start : start + 3] = left_a[start : start + 3]
+    a[end - 3 : end] = right_a[end - 3 : end]
+
+    date = "2020-03-01"
+    alignments = {"left": left_a, "mid": mid_a, "right": right_a}
+    ds = sc2ts.tmp_dataset(tmp_path / "tmp.zarr", alignments, date=date)
+
+    ts = sc2ts.extend(
+        dataset=ds.path,
+        base_ts=base_ts.path,
+        date=date,
+        hmm_cost_threshold=15,
+        match_db=sc2ts.MatchDb.initialise(tmp_path / "match.db").path,
+    )
+    samples_strain = ts.metadata["sc2ts"]["samples_strain"]
+    assert samples_strain[-3:] == ["left", "mid", "right"]
+    assert ts.num_nodes == base_ts.num_nodes + 3
+    assert ts.num_edges == base_ts.num_edges + 3
+    assert ts.num_mutations == base_ts.num_mutations + 15
+
+    left_node, mid_node, right_node = ts.samples()[-3:]
+
+    for j, mut_id in enumerate(np.where(ts.mutations_node == left_node)[0]):
+        mut = ts.mutation(mut_id)
+        assert mut.derived_state == "G"
+        assert ts.sites_position[mut.site] == start + j + 1
+
+    for j, mut_id in enumerate(np.where(ts.mutations_node == mid_node)[0]):
+        mut = ts.mutation(mut_id)
+        assert mut.derived_state == "C"
+        assert ts.sites_position[mut.site] == mid_start + j + 1
+
+    for j, mut_id in enumerate(np.where(ts.mutations_node == right_node)[0]):
+        mut = ts.mutation(mut_id)
+        assert mut.derived_state == "C"
+        assert ts.sites_position[mut.site] == end - 3 + j + 1
+
+    ts_path = tmp_path / "intermediate.ts"
+    ts.dump(ts_path)
+
+    # Now run again with the recombinant of these three, encoding the intervals in the name
+    date = "2020-03-02"
+    left = start + 3 + 1
+    right = end - 3 + 1
+    name = f"recombinant_{left}:{mid_start + 1}:{mid_end + 1}:{right}"
+    ds = sc2ts.tmp_dataset(tmp_path / "tmp.zarr", {name: a}, date=date)
+    rts = sc2ts.extend(
+        dataset=ds.path,
+        base_ts=ts_path,
+        date=date,
+        hmm_cost_threshold=15,
+        match_db=sc2ts.MatchDb.initialise(tmp_path / "match.db").path,
+    )
+    assert rts.num_samples == ts.num_samples + 1
+    return rts
 
 
 @pytest.fixture
 def fx_recombinant_example_1(tmp_path, fx_data_cache, fx_ts_map, fx_dataset):
@@ -299,10 +357,21 @@ def fx_recombinant_example_1(tmp_path, fx_data_cache, fx_ts_map, fx_dataset):
 
 
 @pytest.fixture
-def fx_recombinant_example_2(tmp_path, fx_data_cache, fx_ts_map, fx_alignment_store):
+def fx_recombinant_example_2(tmp_path, fx_data_cache, fx_ts_map, fx_dataset):
     cache_path = fx_data_cache / "recombinant_ex2.ts"
     if not cache_path.exists():
         print(f"Generating {cache_path}")
-        ts = recombinant_example_2(tmp_path, fx_ts_map, fx_alignment_store)
+        ds_cache_path = fx_data_cache / "recombinant_ex2_dataset.zarr"
+        ts = recombinant_example_2(tmp_path, fx_ts_map, fx_dataset, ds_cache_path)
         ts.dump(cache_path)
     return tskit.load(cache_path)
+
+
+@pytest.fixture
+def fx_recombinant_example_3(tmp_path, fx_data_cache, fx_ts_map, fx_dataset):
+    cache_path = fx_data_cache / "recombinant_ex3.ts"
+    if not cache_path.exists():
+        print(f"Generating {cache_path}")
+        ds_cache_path = fx_data_cache / "recombinant_ex3_dataset.zarr"
+        ts = recombinant_example_3(tmp_path, fx_ts_map, fx_dataset, ds_cache_path)
+        ts.dump(cache_path)
+    return tskit.load(cache_path)
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 563773b..7f0e603 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -369,44 +369,6 @@ def test_multiple_override(self, tmp_path, fx_ts_map, fx_dataset):
         assert ts.num_samples == 0
 
 
-@pytest.mark.skip("Broken by dataset")
-class TestRunRematchRecombinants:
-
-    @pytest.mark.parametrize("num_threads", [0, 1, 2])
-    def test_defaults(
-        self, tmp_path, fx_recombinant_example_1, fx_data_cache, num_threads
-    ):
-        ts_path = fx_data_cache / "recombinant_ex1.ts"
-        as_path = fx_data_cache / "recombinant_ex1_alignments.db"
-        pattern = str(fx_data_cache) + "/{}.ts"
-        runner = ct.CliRunner(mix_stderr=False)
-        cmd = (
-            f"rematch-recombinants {as_path} {ts_path} {pattern} "
-            f"--num-threads={num_threads}"
-        )
-        result = runner.invoke(
-            cli.cli,
-            cmd,
-            catch_exceptions=False,
-        )
-        assert result.exit_code == 0
-        lines = result.stdout.splitlines()
-        assert len(lines) == 4
-        results = collections.defaultdict(list)
-        for line in lines:
-            d = json.loads(line)
-            results[d["strain"]].append(result)
-
-        assert len(results) == 2
-        assert set(results.keys()) == {
-            "recombinant_example_1_0",
-            "recombinant_example_1_1",
-        }
-
-        assert len(results["recombinant_example_1_0"]) == 2
-        assert len(results["recombinant_example_1_1"]) == 2
-
-
 class TestValidate:
 
     @pytest.mark.parametrize("date", ["2020-01-01", "2020-02-11"])
diff --git a/tests/test_inference.py b/tests/test_inference.py
index c24bb51..f7d8aab 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -78,10 +78,48 @@ def test_get_recombinant_strains_ex1(self, fx_recombinant_example_1):
         d = sc2ts.get_recombinant_strains(fx_recombinant_example_1)
         assert d == {55: ["recombinant_example_1_0", "recombinant_example_1_1"]}
 
-    @pytest.mark.skip("Example broken by dataset")
     def test_get_recombinant_strains_ex2(self, fx_recombinant_example_2):
         d = sc2ts.get_recombinant_strains(fx_recombinant_example_2)
-        assert d == {56: ["recombinant"]}
+        assert d == {56: ["recombinant_114:29825"]}
+
+    def test_recombinant_example_1(self, fx_recombinant_example_1):
+        ts = fx_recombinant_example_1
+        samples_strain = ts.metadata["sc2ts"]["samples_strain"]
+        samples = ts.samples()
+        for s in ["recombinant_example_1_0", "recombinant_example_1_1"]:
+            u = samples[samples_strain.index(s)]
+            node = ts.node(u)
+            md = node.metadata["sc2ts"]
+            assert md["breakpoint_intervals"] == [[3788, 11083]]
+            assert md["hmm_match"]["path"] == [
+                {"left": 0, "parent": 31, "right": 11083},
+                {"left": 11083, "parent": 46, "right": 29904},
+            ]
+
+    def test_recombinant_example_2(self, fx_recombinant_example_2):
+        ts = fx_recombinant_example_2
+        samples_strain = ts.metadata["sc2ts"]["samples_strain"]
+        u = ts.samples()[samples_strain.index("recombinant_114:29825")]
+        node = ts.node(u)
+        md = node.metadata["sc2ts"]
+        assert md["breakpoint_intervals"] == [[114, 29825]]
+        assert md["hmm_match"]["path"] == [
+            {"left": 0, "parent": 53, "right": 29825},
+            {"left": 29825, "parent": 54, "right": 29904},
+        ]
+
+    def test_recombinant_example_3(self, fx_recombinant_example_3):
+        ts = fx_recombinant_example_3
+        samples_strain = ts.metadata["sc2ts"]["samples_strain"]
+        u = ts.samples()[samples_strain.index("recombinant_114:15001:15010:29825")]
+        node = ts.node(u)
+        md = node.metadata["sc2ts"]
+        assert md["breakpoint_intervals"] == [[114, 15001], [15010, 29825]]
+        assert md["hmm_match"]["path"] == [
+            {"left": 0, "parent": 53, "right": 15001},
+            {"left": 15001, "parent": 54, "right": 29825},
+            {"left": 29825, "parent": 55, "right": 29904},
+        ]
 
 
 class TestSolveNumMismatches:
@@ -984,7 +1022,6 @@ def test_recombinant_example_1(self, fx_ts_map, fx_recombinant_example_1):
                 {"left": bp, "parent": right_parent, "right": 29904},
             ],
         }
-        assert smd["hmm_reruns"] == {}
 
         sample = ts.node(ts.samples()[-1])
         smd = sample.metadata["sc2ts"]
@@ -998,7 +1035,6 @@ def test_recombinant_example_1(self, fx_ts_map, fx_recombinant_example_1):
                 {"left": bp, "parent": right_parent, "right": 29904},
             ],
        }
-        assert smd["hmm_reruns"] == {}
 
         recomb_node = ts.node(ts.num_nodes - 1)
         assert recomb_node.flags == sc2ts.NODE_IS_RECOMBINANT
@@ -1035,13 +1071,12 @@ def test_recombinant_example_1(self, fx_ts_map, fx_recombinant_example_1):
         assert row.parents == 2
         assert row.causal_pango == {"Unknown": 2}
 
-    @pytest.mark.skip("Example broken by dataset")
     def test_recombinant_example_2(self, fx_ts_map, fx_recombinant_example_2):
         base_ts = fx_ts_map["2020-02-13"]
         date = "2020-03-01"
         rts = fx_recombinant_example_2
         samples_strain = rts.metadata["sc2ts"]["samples_strain"]
-        assert samples_strain[-3:] == ["left", "right", "recombinant"]
+        assert samples_strain[-3:] == ["left", "right", "recombinant_114:29825"]
 
         sample = rts.node(rts.samples()[-1])
         smd = sample.metadata["sc2ts"]
@@ -1053,8 +1088,6 @@ def test_recombinant_example_2(self, fx_ts_map, fx_recombinant_example_2):
             ],
         }
 
-        assert smd["hmm_reruns"] == {}
-
     def test_all_As(self, tmp_path, fx_ts_map, fx_dataset):
         # Same as the recombinant_example_1() function above
         # Just to get something that looks like an alignment easily
@@ -1189,21 +1222,23 @@ def test_match_recombinant(self, fx_ts_map):
         assert m.path[1].right == ts.sequence_length
 
 
-class TestMatchRecombinants:
+class TestCharacteriseRecombinants:
+
     def test_example_1(self, fx_ts_map):
         ts, s = recombinant_example_1(fx_ts_map)
-        sc2ts.match_recombinants(
-            samples=[s],
-            base_ts=ts,
-            num_mismatches=2,
-            num_threads=0,
-        )
+        interval_left = 3788
+        interval_right = 11083
         left_parent = 31
         right_parent = 46
-        interval_right = 11083
 
-        m = s.hmm_reruns["forward"]
+        sc2ts.match_tsinfer(
+            samples=[s],
+            ts=ts,
+            num_mismatches=2,
+            mismatch_threshold=10,
+        )
+        m = s.hmm_match
         assert len(m.mutations) == 0
         assert len(m.path) == 2
         assert m.path[0].parent == left_parent
@@ -1213,45 +1248,61 @@
         assert m.path[1].left == interval_right
         assert m.path[1].right == ts.sequence_length
 
-        interval_left = 3788
-        m = s.hmm_reruns["reverse"]
+        sc2ts.characterise_recombinants(ts, [s])
+        assert s.breakpoint_intervals == [(interval_left, interval_right)]
+
+        sc2ts.match_tsinfer(
+            samples=[s],
+            ts=ts,
+            num_mismatches=2,
+            mismatch_threshold=10,
+            mirror_coordinates=True,
+        )
+        m = s.hmm_match
         assert len(m.mutations) == 0
         assert len(m.path) == 2
+        assert m.path[0].parent == left_parent
         assert m.path[0].left == 0
         assert m.path[0].right == interval_left
-        assert m.path[0].parent == left_parent
         assert m.path[1].parent == right_parent
         assert m.path[1].left == interval_left
         assert m.path[1].right == ts.sequence_length
 
-        m = s.hmm_reruns["no_recombination"]
-        # It seems that we can choose either the left or right parent
-        # arbitrarily :shrug:
-        assert len(m.mutations) == 3
-        assert m.mutation_summary() in [
-            "[A871G, A3027G, C3787T]",
-            "[T11083G, C15324T, C29303T]",
-        ]
-        assert len(m.path) == 1
-        assert m.path[0].parent in [left_parent, right_parent]
-        assert m.path[0].left == 0
-        assert m.path[0].right == ts.sequence_length
-
-        assert "no_recombination" in s.summary()
-
-    def test_all_As(self, fx_ts_map):
-        ts = fx_ts_map["2020-02-13"]
-        h = np.zeros(ts.num_sites, dtype=np.int8)
-        s = sc2ts.Sample("zerotype", "2020-02-14", haplotype=h)
-
-        sc2ts.match_recombinants(
+    def test_example_3(self, fx_recombinant_example_3):
+        ts = fx_recombinant_example_3
+        strains = ts.metadata["sc2ts"]["samples_strain"]
+        assert strains[-1].startswith("recomb")
+        u = ts.samples()[-1]
+        h = ts.genotype_matrix(samples=[u], alleles=tuple(sc2ts.IUPAC_ALLELES)).T[0]
+        tables = ts.dump_tables()
+        keep_edges = ts.edges_child < u
+        tables.edges.keep_rows(keep_edges)
+        keep_nodes = np.ones(ts.num_nodes, dtype=bool)
+        tables.nodes[u] = tables.nodes[u].replace(flags=0)
+        tables.sort()
+        base_ts = tables.tree_sequence()
+
+        s = sc2ts.Sample("3way", "2020-02-14", haplotype=h.astype(np.int8))
+        sc2ts.match_tsinfer(
             samples=[s],
-            base_ts=ts,
-            num_mismatches=3,
-            num_threads=0,
+            ts=base_ts,
+            num_mismatches=2,
+            mismatch_threshold=10,
+            mirror_coordinates=False,
         )
-        assert len(s.hmm_reruns) == 3
-        num_mutations = []
-        for hmm_match in s.hmm_reruns.values():
-            assert len(hmm_match.path) == 1
-            assert len(hmm_match.mutations) == 20943
+        sc2ts.characterise_recombinants(ts, [s])
+        m = s.hmm_match
+        assert m.parents == [53, 54, 55]
+        assert m.breakpoints == [0, 15001, 29825, 29904]
+        assert s.breakpoint_intervals == [(114, 15001), (15010, 29825)]
+        # Verify that these breakpoints correspond to the reverse-direction HMM
+        sc2ts.match_tsinfer(
+            samples=[s],
+            ts=base_ts,
+            num_mismatches=2,
+            mismatch_threshold=10,
+            mirror_coordinates=True,
+        )
+        m = s.hmm_match
+        assert m.parents == [53, 54, 55]
+        assert m.breakpoints == [0, 114, 15010, 29904]
diff --git a/tests/test_utils.py b/tests/test_utils.py
deleted file mode 100644
index 3b0eb91..0000000
--- a/tests/test_utils.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import numpy as np
-import pytest
-
-import sc2ts
-import sc2ts.utils as utils
-import util
-
-
-class TestPadSites:
-    def check_site_padding(self, ts):
-        ts = sc2ts.pad_sites(ts)
-        ref = sc2ts.core.get_reference_sequence(as_array=True)
-        assert ts.num_sites == len(ref) - 1
-        ancestral_state = ts.tables.sites.ancestral_state.view("S1").astype(str)
-        assert np.all(ancestral_state == ref[1:])
-
-    def test_initial(self):
-        self.check_site_padding(sc2ts.initial_ts())