Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Mar 4, 2024
1 parent 8b1be4f commit 9b31dd8
Showing 1 changed file with 63 additions and 1 deletion.
64 changes: 63 additions & 1 deletion python/tests/test_haplotype_matching.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# MIT License
#
# Copyright (c) 2019-2023 Tskit Developers
# Copyright (c) 2019-2024 Tskit Developers
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -484,6 +484,30 @@ def run(self, h, normalisation_factor):
return self.output


class ForwardBackwardAlgorithm(BackwardAlgorithm):
"""
The Li and Stephens forward-backward algorithm.
Alternatively, it is more efficient to compute this
on the fly during a run of the backward algorithm.
"""

def compute_next_probability(self, site_id, p_next, is_match, node):
p_e = self.compute_emission_proba(site_id, is_match)
bwd_prob = p_next * p_e
# Get forward prob. from compressed matrix.
fwd_prob = 0
return bwd_prob * fwd_prob

def run(self, h, normalisation_factor):
self.initialise(value=1)
while self.tree.prev():
self.update_tree(direction=tskit.REVERSE)
for site in reversed(list(self.tree.sites())):
self.process_site(site, h[site.id], normalisation_factor[site.id])
return self.output


class ViterbiAlgorithm(LsHmmAlgorithm):
"""
Runs the Li and Stephens Viterbi algorithm.
Expand Down Expand Up @@ -729,6 +753,40 @@ def ls_backward_tree(h, ts, rho, mu, normalisation_factor, precision=30, alleles
return ba.run(h, normalisation_factor)


def ls_fb_tree(
h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False
):
alleles, n_alleles = get_site_alleles(ts, h, alleles)
fa = ForwardAlgorithm(
ts,
rho,
mu,
alleles,
n_alleles,
precision=precision,
scale_mutation=scale_mutation_based_on_n_alleles,
)
forward_cm = fa.run(h)
ba = BackwardAlgorithm(
ts,
rho,
mu,
alleles,
n_alleles,
precision=precision,
)
_ = ba.run(h, forward_cm.normalisation_factor)
fba = ForwardBackwardAlgorithm(
ts,
rho,
mu,
alleles,
n_alleles,
precision=precision,
)
return fba.run(h)


def ls_viterbi_tree(
h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False
):
Expand Down Expand Up @@ -1184,6 +1242,10 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None):
nt.assert_allclose(B, B_lib)


def check_fb_matrix():
raise NotImplementedError()


def add_unique_sample_mutations(ts, start=0):
"""
Adds a mutation for each of the samples at equally spaced locations
Expand Down

0 comments on commit 9b31dd8

Please sign in to comment.