Skip to content

Commit

Permalink
Remove hmm_reruns and other infrastructure for rematching recombinants
Browse files Browse the repository at this point in the history
Closes #159
  • Loading branch information
jeromekelleher committed Dec 13, 2024
1 parent 28b1cb0 commit a656944
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 213 deletions.
98 changes: 0 additions & 98 deletions sc2ts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,103 +633,6 @@ def find_previous_date_path(date, path_pattern):
return path


@click.command()
@click.argument("dataset", type=click.Path(exists=True, dir_okay=False))
@click.argument("ts", type=click.Path(exists=True, dir_okay=False))
@click.argument("path_pattern")
@num_mismatches
@click.option(
    "--num-threads",
    default=0,
    type=int,
    help="Number of match threads (default to one)",
)
@click.option("--progress/--no-progress", default=True)
@click.option("-v", "--verbose", count=True)
@click.option("-l", "--log-file", default=None, type=click.Path(dir_okay=False))
def rematch_recombinants(
    dataset,
    ts,
    path_pattern,
    num_mismatches,
    num_threads,
    progress,
    verbose,
    log_file,
):
    """
    Rerun the HMM in the forward and reverse directions for the strains
    behind each recombinant in TS, writing one JSON record per HMM run to
    stdout.

    Each recombinant is matched against the tree sequence file returned by
    find_previous_date_path for the recombinant's "date_added" metadata
    value and PATH_PATTERN.
    """
    setup_logging(verbose, log_file)
    ts = tszip.load(ts)
    # This is a map of recombinant node to the samples involved in
    # the original causal sample group.
    recombinant_strains = sc2ts.get_recombinant_strains(ts)
    logger.info(
        f"Got {len(recombinant_strains)} recombinants and "
        f"{sum(len(v) for v in recombinant_strains.values())} strains"
    )

    # Map recombinants to the base tree sequence path for their originating
    # date, and record which recombinant each strain belongs to.
    recombinant_to_path = {}
    strain_to_recombinant = {}
    all_strains = []
    for u, strains in recombinant_strains.items():
        date_added = ts.node(u).metadata["sc2ts"]["date_added"]
        recombinant_to_path[u] = find_previous_date_path(date_added, path_pattern)
        for strain in strains:
            strain_to_recombinant[strain] = u
            all_strains.append(strain)

    ds = sc2ts.Dataset(dataset)
    progress_title = "Recomb"
    # NOTE(review): this keyword was previously spelled ``datset``, which
    # looks like a typo; confirm that sc2ts.preprocess names it ``dataset``.
    samples = sc2ts.preprocess(
        all_strains,
        dataset=ds,
        show_progress=progress,
        progress_title=progress_title,
        keep_sites=ts.sites_position.astype(int),
        num_workers=num_threads,
    )

    # Group the preprocessed samples by the recombinant they belong to.
    recombinant_to_samples = collections.defaultdict(list)
    for sample in samples:
        if sample.haplotype is None:
            raise ValueError(f"No alignment stored for {sample.strain}")
        recombinant = strain_to_recombinant[sample.strain]
        recombinant_to_samples[recombinant].append(sample)

    # One unit of work per (recombinant, direction) pair.
    work = [
        MatchWork(
            recombinant_to_path[recombinant],
            group_samples,
            num_mismatches=num_mismatches,
            direction=direction,
        )
        for recombinant, group_samples in recombinant_to_samples.items()
        for direction in ["forward", "reverse"]
    ]

    bar = sc2ts.get_progress(None, progress_title, "HMM", progress, total=len(work))

    def output(hmm_runs):
        bar.update()
        for run in hmm_runs:
            print(run.asjson())

    # Close the progress bar even when a worker raises.
    try:
        if num_threads == 0:
            # Run serially in this process.
            for w in work:
                output(_match_worker(w))
        else:
            with cf.ProcessPoolExecutor(num_threads) as executor:
                futures = [executor.submit(_match_worker, w) for w in work]
                # Emit results as they finish; output order is not the
                # submission order.
                for future in cf.as_completed(futures):
                    output(future.result())
    finally:
        bar.close()


@click.version_option(core.__version__)
@click.group()
def cli():
Expand All @@ -747,5 +650,4 @@ def cli():
cli.add_command(infer)
cli.add_command(validate)
cli.add_command(_match)
cli.add_command(rematch_recombinants)
cli.add_command(tally_lineages)
43 changes: 1 addition & 42 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,6 @@ class Sample:
hmm_match: HmmMatch = None
breakpoint_intervals: List = dataclasses.field(default_factory=list)
flags: int = tskit.NODE_IS_SAMPLE
hmm_reruns: Dict = dataclasses.field(default_factory=dict)

@property
def is_recombinant(self):
Expand All @@ -382,46 +381,7 @@ def num_deletion_sites(self):

def summary(self):
hmm_match = "No match" if self.hmm_match is None else self.hmm_match.summary()
s = f"{self.strain} {self.date} {self.pango} {hmm_match}"
for name, hmm_match in self.hmm_reruns.items():
s += f"; {name}: {hmm_match.summary()}"
return s


# TODO not clear if we still need this as mirroring is done differently now.
# Remove if we don't have any issues with running the HMM in reverse
def pad_sites(ts):
    """
    Return a copy of ``ts`` with a site added at every position from 1 up
    to (but excluding) the reference sequence length that has no site yet.

    Each added site takes the reference sequence value at its position as
    the second argument to ``add_row`` (presumably the ancestral state).
    """
    reference = core.get_reference_sequence()
    absent_positions = set(np.arange(1, len(reference)))
    absent_positions.difference_update(set(ts.sites_position.astype(int)))

    padded_tables = ts.dump_tables()
    for position in absent_positions:
        padded_tables.sites.add_row(position, reference[position])
    # Rows were appended in set-iteration order, so sort before building.
    padded_tables.sort()
    return padded_tables.tree_sequence()


# TODO remove this
def match_recombinants(
    samples, base_ts, num_mismatches, show_progress=False, num_threads=None
):
    """
    Run ``match_tsinfer`` three times over ``samples`` against ``base_ts``
    ("forward", "reverse" and "no_recombination" passes), storing each
    sample's resulting ``hmm_match`` in ``sample.hmm_reruns`` under the
    pass name.

    The "reverse" pass sets ``mirror_coordinates=True``. The
    "no_recombination" pass uses ``num_mismatches=1000`` in place of the
    caller's value. Samples are updated in place; nothing is returned.
    """
    for pass_name in ("forward", "reverse", "no_recombination"):
        logger.info(f"Running {pass_name} pass for {len(samples)} recombinants")
        if pass_name == "no_recombination":
            pass_mismatches = 1000
        else:
            pass_mismatches = num_mismatches
        run_mirrored = pass_name == "reverse"
        match_tsinfer(
            samples=samples,
            ts=base_ts,
            num_mismatches=pass_mismatches,
            mismatch_threshold=100,
            num_threads=num_threads,
            show_progress=show_progress,
            mirror_coordinates=run_mirrored,
        )
        # Save this pass's match before the next pass overwrites hmm_match.
        for matched in samples:
            matched.hmm_reruns[pass_name] = matched.hmm_match
return f"{self.strain} {self.date} {self.pango} {hmm_match}"


def match_samples(
Expand Down Expand Up @@ -838,7 +798,6 @@ def update_top_level_metadata(ts, date, retro_groups, samples):
def add_sample_to_tables(sample, tables, group_id=None):
sc2ts_md = {
"hmm_match": sample.hmm_match.asdict(),
"hmm_reruns": {k: m.asdict() for k, m in sample.hmm_reruns.items()},
"alignment_composition": dict(sample.alignment_composition),
"num_missing_sites": sample.num_missing_sites,
}
Expand Down
74 changes: 1 addition & 73 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,7 +1022,6 @@ def test_recombinant_example_1(self, fx_ts_map, fx_recombinant_example_1):
{"left": bp, "parent": right_parent, "right": 29904},
],
}
assert smd["hmm_reruns"] == {}

sample = ts.node(ts.samples()[-1])
smd = sample.metadata["sc2ts"]
Expand All @@ -1036,7 +1035,6 @@ def test_recombinant_example_1(self, fx_ts_map, fx_recombinant_example_1):
{"left": bp, "parent": right_parent, "right": 29904},
],
}
assert smd["hmm_reruns"] == {}

recomb_node = ts.node(ts.num_nodes - 1)
assert recomb_node.flags == sc2ts.NODE_IS_RECOMBINANT
Expand Down Expand Up @@ -1078,7 +1076,7 @@ def test_recombinant_example_2(self, fx_ts_map, fx_recombinant_example_2):
date = "2020-03-01"
rts = fx_recombinant_example_2
samples_strain = rts.metadata["sc2ts"]["samples_strain"]
assert samples_strain[-3:] == ["left", "right", "recombinant"]
assert samples_strain[-3:] == ["left", "right", "recombinant_114:29825"]

sample = rts.node(rts.samples()[-1])
smd = sample.metadata["sc2ts"]
Expand All @@ -1090,8 +1088,6 @@ def test_recombinant_example_2(self, fx_ts_map, fx_recombinant_example_2):
],
}

assert smd["hmm_reruns"] == {}

def test_all_As(self, tmp_path, fx_ts_map, fx_dataset):
# Same as the recombinant_example_1() function above
# Just to get something that looks like an alignment easily
Expand Down Expand Up @@ -1310,71 +1306,3 @@ def test_example_3(self, fx_recombinant_example_3):
m = s.hmm_match
assert m.parents == [53, 54, 55]
assert m.breakpoints == [0, 114, 15010, 29904]


class TestMatchRecombinants:
    """
    Tests for ``sc2ts.match_recombinants``, checking the "forward",
    "reverse" and "no_recombination" entries it stores in
    ``Sample.hmm_reruns``.
    """

    def test_example_1(self, fx_ts_map):
        """
        Two-parent recombinant: the forward and reverse passes find the same
        parents at different breakpoints, and the no-recombination pass
        gives a single-parent path with three mutations.
        """
        ts, s = recombinant_example_1(fx_ts_map)

        sc2ts.match_recombinants(
            samples=[s],
            base_ts=ts,
            num_mismatches=2,
            num_threads=0,
        )
        left_parent = 31
        right_parent = 46
        interval_right = 11083

        # Forward pass: breakpoint expected at interval_right.
        m = s.hmm_reruns["forward"]
        assert len(m.mutations) == 0
        assert len(m.path) == 2
        assert m.path[0].parent == left_parent
        assert m.path[0].left == 0
        assert m.path[0].right == interval_right
        assert m.path[1].parent == right_parent
        assert m.path[1].left == interval_right
        assert m.path[1].right == ts.sequence_length

        # Reverse pass: same parents, breakpoint expected at interval_left.
        interval_left = 3788
        m = s.hmm_reruns["reverse"]
        assert len(m.mutations) == 0
        assert len(m.path) == 2
        assert m.path[0].left == 0
        assert m.path[0].right == interval_left
        assert m.path[0].parent == left_parent
        assert m.path[1].parent == right_parent
        assert m.path[1].left == interval_left
        assert m.path[1].right == ts.sequence_length

        m = s.hmm_reruns["no_recombination"]
        # It seems that we can choose either the left or right parent
        # arbitrarily :shrug:
        assert len(m.mutations) == 3
        assert m.mutation_summary() in [
            "[A871G, A3027G, C3787T]",
            "[T11083G, C15324T, C29303T]",
        ]
        assert len(m.path) == 1
        assert m.path[0].parent in [left_parent, right_parent]
        assert m.path[0].left == 0
        assert m.path[0].right == ts.sequence_length

        # The reruns must also show up in the sample's summary string.
        assert "no_recombination" in s.summary()

    def test_all_As(self, fx_ts_map):
        """
        An all-zero haplotype gets a single-segment path with the same
        mutation count in each of the three passes.
        """
        ts = fx_ts_map["2020-02-13"]
        h = np.zeros(ts.num_sites, dtype=np.int8)
        s = sc2ts.Sample("zerotype", "2020-02-14", haplotype=h)

        sc2ts.match_recombinants(
            samples=[s],
            base_ts=ts,
            num_mismatches=3,
            num_threads=0,
        )
        assert len(s.hmm_reruns) == 3
        for hmm_match in s.hmm_reruns.values():
            assert len(hmm_match.path) == 1
            assert len(hmm_match.mutations) == 20943

0 comments on commit a656944

Please sign in to comment.