Skip to content

Commit

Permalink
Add method to characterise recombinants
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Dec 13, 2024
1 parent 80df1c5 commit 00b332d
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 5 deletions.
35 changes: 33 additions & 2 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,9 @@ class Sample:
alignment_composition: Dict = None
haplotype: List = None
hmm_match: HmmMatch = None
hmm_reruns: Dict = dataclasses.field(default_factory=dict)
breakpoint_intervals: List = dataclasses.field(default_factory=list)
flags: int = tskit.NODE_IS_SAMPLE
hmm_reruns: Dict = dataclasses.field(default_factory=dict)

@property
def is_recombinant(self):
Expand Down Expand Up @@ -725,8 +726,9 @@ def _extend(
num_threads=num_threads,
memory_limit=memory_limit,
)

characterise_match_mutations(base_ts, samples)
characterise_recombinants(base_ts, samples)

for sample in unconditional_include_samples:
# We want this sample to included unconditionally, so we set the
# hmm cost to 0 < hmm_cost < hmm_cost_threshold. We use 0.5
Expand Down Expand Up @@ -840,6 +842,8 @@ def add_sample_to_tables(sample, tables, group_id=None):
"alignment_composition": dict(sample.alignment_composition),
"num_missing_sites": sample.num_missing_sites,
}
if sample.is_recombinant:
sc2ts_md["breakpoint_intervals"] = sample.breakpoint_intervals
if group_id is not None:
sc2ts_md["group_id"] = group_id
metadata = {**sample.metadata, "sc2ts": sc2ts_md}
Expand Down Expand Up @@ -1681,6 +1685,33 @@ def get_closest_mutation(node, site_id):
logger.debug(f"Characterised {num_mutations} mutations")


def characterise_recombinants(ts, samples):
"""
Update the metadata for any recombinants to add interval information to the metadata.
"""
recombinants = [s for s in samples if s.is_recombinant]
if len(recombinants) == 0:
return
logger.info(f"Characterising {len(recombinants)} recombinants")

# NOTE: could make this more efficient by doing one call to genotype_matrix,
# but recombinants are rare so let's keep this simple
for s in recombinants:
parents = [seg.parent for seg in s.hmm_match.path]
# Can't have missing data here, so we're OK.
H = ts.genotype_matrix(samples=parents, isolated_as_missing=False).T
breakpoint_intervals = []
for j in range(len(parents) - 1):
parents_differ = np.where(H[j] != H[j + 1])[0]
pos = ts.sites_position[parents_differ].astype(int)
right = s.hmm_match.path[j].right
right_index = np.searchsorted(pos, right)
assert pos[right_index] == right
left = pos[right_index - 1] + 1
breakpoint_intervals.append((int(left), int(right)))
s.breakpoint_intervals = breakpoint_intervals


def attach_tree(
parent_ts,
parent_tables,
Expand Down
6 changes: 4 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,11 @@ def recombinant_example_2(tmp_path, fx_ts_map, fx_dataset, ds_path):
ts_path = tmp_path / "intermediate.ts"
ts.dump(ts_path)

# Now run again with the recombinant of these two
# Now run again with the recombinant of these two, encoding the interval in the # name
date = "2020-03-02"
ds = sc2ts.tmp_dataset(tmp_path / "tmp.zarr", {"recombinant": a}, date=date)
left = start + 3 + 1
right = end - 3 + 1
ds = sc2ts.tmp_dataset(tmp_path / "tmp.zarr", {f"recombinant_{left}:{right}": a}, date=date)
rts = sc2ts.extend(
dataset=ds.path,
base_ts=ts_path,
Expand Down
75 changes: 74 additions & 1 deletion tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,33 @@ def test_get_recombinant_strains_ex1(self, fx_recombinant_example_1):

def test_get_recombinant_strains_ex2(self, fx_recombinant_example_2):
d = sc2ts.get_recombinant_strains(fx_recombinant_example_2)
assert d == {56: ["recombinant"]}
assert d == {56: ["recombinant_114:29825"]}

def test_recombinant_example_1(self, fx_recombinant_example_1):
ts = fx_recombinant_example_1
samples_strain = ts.metadata["sc2ts"]["samples_strain"]
samples = ts.samples()
for s in ["recombinant_example_1_0", "recombinant_example_1_1"]:
u = samples[samples_strain.index(s)]
node = ts.node(u)
md = node.metadata["sc2ts"]
assert md["breakpoint_intervals"] == [[3788, 11083]]
assert md["hmm_match"]["path"] == [
{"left": 0, "parent": 31, "right": 11083},
{"left": 11083, "parent": 46, "right": 29904},
]

def test_recombinant_example_2(self, fx_recombinant_example_2):
ts = fx_recombinant_example_2
samples_strain = ts.metadata["sc2ts"]["samples_strain"]
u = ts.samples()[samples_strain.index("recombinant_114:29825")]
node = ts.node(u)
md = node.metadata["sc2ts"]
assert md["breakpoint_intervals"] == [[114, 29825]]
assert md["hmm_match"]["path"] == [
{"left": 0, "parent": 53, "right": 29825},
{"left": 29825, "parent": 54, "right": 29904},
]


class TestSolveNumMismatches:
Expand Down Expand Up @@ -1187,6 +1213,53 @@ def test_match_recombinant(self, fx_ts_map):
assert m.path[1].right == ts.sequence_length


class TestCharacteriseRecombinants:

def test_example_1(self, fx_ts_map):
ts, s = recombinant_example_1(fx_ts_map)

interval_left = 3788
interval_right = 11083
left_parent = 31
right_parent = 46

sc2ts.match_tsinfer(
samples=[s],
ts=ts,
num_mismatches=2,
mismatch_threshold=10,
)
m = s.hmm_match
assert len(m.mutations) == 0
assert len(m.path) == 2
assert m.path[0].parent == left_parent
assert m.path[0].left == 0
assert m.path[0].right == interval_right
assert m.path[1].parent == right_parent
assert m.path[1].left == interval_right
assert m.path[1].right == ts.sequence_length

sc2ts.characterise_recombinants(ts, [s])
assert s.breakpoint_intervals == [(interval_left, interval_right)]

sc2ts.match_tsinfer(
samples=[s],
ts=ts,
num_mismatches=2,
mismatch_threshold=10,
mirror_coordinates=True,
)
m = s.hmm_match
assert len(m.mutations) == 0
assert len(m.path) == 2
assert m.path[0].parent == left_parent
assert m.path[0].left == 0
assert m.path[0].right == interval_left
assert m.path[1].parent == right_parent
assert m.path[1].left == interval_left
assert m.path[1].right == ts.sequence_length


class TestMatchRecombinants:
def test_example_1(self, fx_ts_map):
ts, s = recombinant_example_1(fx_ts_map)
Expand Down

0 comments on commit 00b332d

Please sign in to comment.