Skip to content

Commit

Permalink
Update tests for haploid case
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jun 21, 2024
1 parent 5583c89 commit 7c1c600
Showing 1 changed file with 10 additions and 17 deletions.
27 changes: 10 additions & 17 deletions python/tests/test_haplotype_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,11 +927,10 @@ def verify(self, ts):
# Ensure that the decoded matrices are the same
flipped_H = np.flip(H, axis=0)
flipped_s = np.flip(s, axis=1)
num_alleles = ls.core.get_num_alleles(flipped_H, flipped_s)
F_mirror_matrix, c, ll = ls.forwards(
reference_panel=flipped_H,
query=flipped_s,
num_alleles=num_alleles,
ploidy=1,
prob_recombination=r_flip,
prob_mutation=np.flip(mu),
scale_mutation_rate=False,
Expand All @@ -953,11 +952,10 @@ def verify(self, ts):
# Warning from lshmm:
# Passed a vector of mutation rates, but rescaling each mutation
# rate conditional on the number of alleles
num_alleles = ls.core.get_num_alleles(H, s)
F, c, ll = ls.forwards(
reference_panel=H,
query=s,
num_alleles=num_alleles,
ploidy=1,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=scale_mutation,
Expand All @@ -983,19 +981,18 @@ class TestForwardBackwardTree(FBAlgorithmBase):

def verify(self, ts):
for n, H, s, r, mu in self.example_parameters_haplotypes(ts):
num_alleles = ls.core.get_num_alleles(H, s)
F, c, ll = ls.forwards(
reference_panel=H,
query=s,
num_alleles=num_alleles,
ploidy=1,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)
B = ls.backwards(
reference_panel=H,
query=s,
num_alleles=num_alleles,
ploidy=1,
normalisation_factor_from_forward=c,
prob_recombination=r,
prob_mutation=mu,
Expand Down Expand Up @@ -1030,11 +1027,10 @@ class TestTreeViterbiHap(VitAlgorithmBase):

def verify(self, ts):
for n, H, s, r, mu in self.example_parameters_haplotypes(ts):
num_alleles = ls.core.get_num_alleles(H, s)
path, ll = ls.viterbi(
reference_panel=H,
query=s,
num_alleles=num_alleles,
ploidy=1,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
Expand All @@ -1050,7 +1046,7 @@ def verify(self, ts):
ll_check = ls.path_loglik(
reference_panel=H,
query=s,
num_alleles=num_alleles,
ploidy=1,
path=path_tree,
prob_recombination=r,
prob_mutation=mu,
Expand All @@ -1072,12 +1068,11 @@ def check_viterbi(ts, h, recombination=None, mutation=None):

G = ts.genotype_matrix()
s = h.reshape(1, m)
num_alleles = ls.core.get_num_alleles(G, s)

path, ll = ls.viterbi(
reference_panel=G,
query=s,
num_alleles=num_alleles,
ploidy=1,
prob_recombination=recombination,
prob_mutation=mutation,
scale_mutation_rate=False,
Expand All @@ -1095,7 +1090,7 @@ def check_viterbi(ts, h, recombination=None, mutation=None):
ll_check = ls.path_loglik(
reference_panel=G,
query=s,
num_alleles=num_alleles,
ploidy=1,
path=path_tree,
prob_recombination=recombination,
prob_mutation=mutation,
Expand Down Expand Up @@ -1131,12 +1126,11 @@ def check_forward_matrix(ts, h, recombination=None, mutation=None):

G = ts.genotype_matrix()
s = h.reshape(1, m)
num_alleles = ls.core.get_num_alleles(G, s)

F, c, ll = ls.forwards(
reference_panel=G,
query=s,
num_alleles=num_alleles,
ploidy=1,
prob_recombination=recombination,
prob_mutation=mutation,
scale_mutation_rate=False,
Expand Down Expand Up @@ -1179,12 +1173,11 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None):

G = ts.genotype_matrix()
s = h.reshape(1, m)
num_alleles = ls.core.get_num_alleles(G, s)

B = ls.backwards(
reference_panel=G,
query=h.reshape(1, m),
num_alleles=num_alleles,
ploidy=1,
normalisation_factor_from_forward=forward_cm.normalisation_factor,
prob_recombination=recombination,
prob_mutation=mutation,
Expand Down

0 comments on commit 7c1c600

Please sign in to comment.