Commit 0a43c33

Merge pull request #448 from jeromekelleher/recombinant-tweaks

Recombinant tweaks

jeromekelleher authored Dec 16, 2024
2 parents d243b52 + b357d4d
Showing 7 changed files with 255 additions and 295 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -58,3 +58,7 @@ packages = ["sc2ts"]

[tool.setuptools_scm]
write_to = "sc2ts/_version.py"

[tool.pytest.ini_options]
testpaths = "tests"
addopts = "--cov=sc2ts --cov-report term-missing"
98 changes: 0 additions & 98 deletions sc2ts/cli.py
@@ -633,103 +633,6 @@ def find_previous_date_path(date, path_pattern):
return path


@click.command()
@click.argument("dataset", type=click.Path(exists=True, dir_okay=False))
@click.argument("ts", type=click.Path(exists=True, dir_okay=False))
@click.argument("path_pattern")
@num_mismatches
@click.option(
"--num-threads",
default=0,
type=int,
help="Number of match threads (default to one)",
)
@click.option("--progress/--no-progress", default=True)
@click.option("-v", "--verbose", count=True)
@click.option("-l", "--log-file", default=None, type=click.Path(dir_okay=False))
def rematch_recombinants(
dataset,
ts,
path_pattern,
num_mismatches,
num_threads,
progress,
verbose,
log_file,
):
setup_logging(verbose, log_file)
ts = tszip.load(ts)
# This is a map of recombinant node to the samples involved in
# the original causal sample group.
recombinant_strains = sc2ts.get_recombinant_strains(ts)
logger.info(
f"Got {len(recombinant_strains)} recombinants and "
f"{sum(len(v) for v in recombinant_strains.values())} strains"
)

# Map recombinants to originating date
recombinant_to_path = {}
strain_to_recombinant = {}
all_strains = []
for u, strains in recombinant_strains.items():
date_added = ts.node(u).metadata["sc2ts"]["date_added"]
base_ts_path = find_previous_date_path(date_added, path_pattern)
recombinant_to_path[u] = base_ts_path
for strain in strains:
strain_to_recombinant[strain] = u
all_strains.append(strain)

ds = sc2ts.Dataset(dataset)
progress_title = "Recomb"
samples = sc2ts.preprocess(
all_strains,
datset=ds,
show_progress=progress,
progress_title=progress_title,
keep_sites=ts.sites_position.astype(int),
num_workers=num_threads,
)

recombinant_to_samples = collections.defaultdict(list)
for sample in samples:
if sample.haplotype is None:
raise ValueError(f"No alignment stored for {sample.strain}")
recombinant = strain_to_recombinant[sample.strain]
recombinant_to_samples[recombinant].append(sample)

work = []
for recombinant, samples in recombinant_to_samples.items():
for direction in ["forward", "reverse"]:
work.append(
MatchWork(
recombinant_to_path[recombinant],
samples,
num_mismatches=num_mismatches,
direction=direction,
)
)

bar = sc2ts.get_progress(None, progress_title, "HMM", progress, total=len(work))

def output(hmm_runs):
bar.update()
for run in hmm_runs:
print(run.asjson())

results = []
if num_threads == 0:
for w in work:
hmm_runs = _match_worker(w)
output(hmm_runs)
else:
with cf.ProcessPoolExecutor(num_threads) as executor:
futures = [executor.submit(_match_worker, w) for w in work]
for future in cf.as_completed(futures):
hmm_runs = future.result()
output(hmm_runs)
bar.close()


@click.version_option(core.__version__)
@click.group()
def cli():
@@ -747,5 +650,4 @@ def cli():
cli.add_command(infer)
cli.add_command(validate)
cli.add_command(_match)
cli.add_command(rematch_recombinants)
cli.add_command(tally_lineages)
78 changes: 34 additions & 44 deletions sc2ts/inference.py
@@ -364,7 +364,7 @@ class Sample:
alignment_composition: Dict = None
haplotype: List = None
hmm_match: HmmMatch = None
hmm_reruns: Dict = dataclasses.field(default_factory=dict)
breakpoint_intervals: List = dataclasses.field(default_factory=list)
flags: int = tskit.NODE_IS_SAMPLE

@property
@@ -381,46 +381,7 @@ def num_deletion_sites(self):

def summary(self):
hmm_match = "No match" if self.hmm_match is None else self.hmm_match.summary()
s = f"{self.strain} {self.date} {self.pango} {hmm_match}"
for name, hmm_match in self.hmm_reruns.items():
s += f"; {name}: {hmm_match.summary()}"
return s


# TODO not clear if we still need this as mirroring is done differently now.
# Remove if we don't have any issues with running the HMM in reverse
def pad_sites(ts):
"""
Fill in missing sites with the reference state.
"""
ref = core.get_reference_sequence()
missing_sites = set(np.arange(1, len(ref)))
missing_sites -= set(ts.sites_position.astype(int))
tables = ts.dump_tables()
for pos in missing_sites:
tables.sites.add_row(pos, ref[pos])
tables.sort()
return tables.tree_sequence()


# TODO remove this
def match_recombinants(
samples, base_ts, num_mismatches, show_progress=False, num_threads=None
):
for hmm_pass in ["forward", "reverse", "no_recombination"]:
logger.info(f"Running {hmm_pass} pass for {len(samples)} recombinants")
match_tsinfer(
samples=samples,
ts=base_ts,
num_mismatches=1000 if hmm_pass == "no_recombination" else num_mismatches,
mismatch_threshold=100,
num_threads=num_threads,
show_progress=show_progress,
mirror_coordinates=hmm_pass == "reverse",
)

for sample in samples:
sample.hmm_reruns[hmm_pass] = sample.hmm_match
return f"{self.strain} {self.date} {self.pango} {hmm_match}"


def match_samples(
@@ -725,8 +686,9 @@ def _extend(
num_threads=num_threads,
memory_limit=memory_limit,
)

characterise_match_mutations(base_ts, samples)
characterise_recombinants(base_ts, samples)

for sample in unconditional_include_samples:
# We want this sample to included unconditionally, so we set the
# hmm cost to 0 < hmm_cost < hmm_cost_threshold. We use 0.5
@@ -836,10 +798,11 @@ def update_top_level_metadata(ts, date, retro_groups, samples):
def add_sample_to_tables(sample, tables, group_id=None):
sc2ts_md = {
"hmm_match": sample.hmm_match.asdict(),
"hmm_reruns": {k: m.asdict() for k, m in sample.hmm_reruns.items()},
"alignment_composition": dict(sample.alignment_composition),
"num_missing_sites": sample.num_missing_sites,
}
if sample.is_recombinant:
sc2ts_md["breakpoint_intervals"] = sample.breakpoint_intervals
if group_id is not None:
sc2ts_md["group_id"] = group_id
metadata = {**sample.metadata, "sc2ts": sc2ts_md}
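
For orientation, a rough sketch of the "sc2ts" metadata block this function now writes for a recombinant sample; the keys follow the code above, but every value below is invented for illustration.

example_sc2ts_md = {
    "hmm_match": {},  # sample.hmm_match.asdict(); structure elided here
    "alignment_composition": {"A": 8954, "C": 5492, "G": 5863, "T": 9594},
    "num_missing_sites": 3,
    # Written only when sample.is_recombinant is True:
    "breakpoint_intervals": [(3038, 14408)],
    # Written only when a group_id is passed:
    "group_id": "d4e5f6",
}
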
@@ -1296,7 +1259,7 @@ def make_tsb(ts, num_alleles, mirror_coordinates=False):
ts.edges_parent[index],
ts.edges_child[index],
)
assert tsb.num_match_nodes == ts.num_nodes
# assert tsb.num_match_nodes == ts.num_nodes

tsb.restore_mutations(
ts.mutations_site, ts.mutations_node, derived_state, ts.mutations_parent
@@ -1681,6 +1644,33 @@ def get_closest_mutation(node, site_id):
logger.debug(f"Characterised {num_mutations} mutations")


def characterise_recombinants(ts, samples):
"""
Update the metadata for any recombinant samples to add breakpoint interval information.
"""
recombinants = [s for s in samples if s.is_recombinant]
if len(recombinants) == 0:
return
logger.info(f"Characterising {len(recombinants)} recombinants")

# NOTE: could make this more efficient by doing one call to genotype_matrix,
# but recombinants are rare so let's keep this simple
for s in recombinants:
parents = [seg.parent for seg in s.hmm_match.path]
# Can't have missing data here, so we're OK.
H = ts.genotype_matrix(samples=parents, isolated_as_missing=False).T
breakpoint_intervals = []
for j in range(len(parents) - 1):
parents_differ = np.where(H[j] != H[j + 1])[0]
pos = ts.sites_position[parents_differ].astype(int)
right = s.hmm_match.path[j].right
right_index = np.searchsorted(pos, right)
assert pos[right_index] == right
left = pos[right_index - 1] + 1
breakpoint_intervals.append((int(left), int(right)))
s.breakpoint_intervals = breakpoint_intervals
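
The interval computation above is easiest to see on toy data. A minimal, self-contained sketch (invented positions, no tree sequence required): the HMM reports the path breakpoint as right, and the interval extends left to one past the previous position at which the two adjacent parents differ.

import numpy as np

# Positions where two adjacent path parents differ (toy values, sorted).
pos = np.array([210, 1059, 3037, 14408, 23403])
# Breakpoint reported for the path segment; by construction it is one of pos.
right = 14408

right_index = np.searchsorted(pos, right)
assert pos[right_index] == right
left = pos[right_index - 1] + 1
print((int(left), int(right)))  # (3038, 14408): the breakpoint uncertainty interval
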


def attach_tree(
parent_ts,
parent_tables,
(The diffs for the remaining changed files are not shown.)