diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index dcc1d684fb..7b0786663e 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2019-2023 Tskit Developers +# Copyright (c) 2019-2024 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -484,6 +484,30 @@ def run(self, h, normalisation_factor): return self.output +class ForwardBackwardAlgorithm(BackwardAlgorithm): + """ + The Li and Stephens forward-backward algorithm. + + Alternatively, it is more efficient to compute this + on the fly during a run of the backward algorithm. + """ + + def compute_next_probability(self, site_id, p_next, is_match, node): + p_e = self.compute_emission_proba(site_id, is_match) + bwd_prob = p_next * p_e + # Get forward prob. from compressed matrix. + fwd_prob = 0 + return bwd_prob * fwd_prob + + def run(self, h, normalisation_factor): + self.initialise(value=1) + while self.tree.prev(): + self.update_tree(direction=tskit.REVERSE) + for site in reversed(list(self.tree.sites())): + self.process_site(site, h[site.id], normalisation_factor[site.id]) + return self.output + + class ViterbiAlgorithm(LsHmmAlgorithm): """ Runs the Li and Stephens Viterbi algorithm. @@ -729,6 +753,40 @@ def ls_backward_tree(h, ts, rho, mu, normalisation_factor, precision=30, alleles return ba.run(h, normalisation_factor) +def ls_fb_tree( + h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False +): + alleles, n_alleles = get_site_alleles(ts, h, alleles) + fa = ForwardAlgorithm( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation_based_on_n_alleles, + ) + forward_cm = fa.run(h) + ba = BackwardAlgorithm( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + ) + _ = ba.run(h, forward_cm.normalisation_factor) + fba = ForwardBackwardAlgorithm( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + ) + return fba.run(h) + + def ls_viterbi_tree( h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False ): @@ -1184,6 +1242,10 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None): nt.assert_allclose(B, B_lib) +def check_fb_matrix(): + raise NotImplementedError() + + def add_unique_sample_mutations(ts, start=0): """ Adds a mutation for each of the samples at equally spaced locations