Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Apr 20, 2024
1 parent 93cb51a commit ec8ed22
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 208 deletions.
14 changes: 14 additions & 0 deletions tests/lsbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ class LSBase:
def verify(self, ts):
    """Check ``ts``; always raises here, so subclasses must override it."""
    raise NotImplementedError()

def verify_larger(self, ts):
    """Optional hook taking ``ts``; a no-op unless a subclass overrides it."""
    pass

def assertAllClose(self, A, B):
    """Assert that A matches B to a relative tolerance of 1e-9.

    The absolute tolerance is zero, so the comparison is purely relative.
    Raises AssertionError (from ``np.testing.assert_allclose``) on mismatch.
    """
    tolerances = {"rtol": 1e-9, "atol": 0.0}
    np.testing.assert_allclose(A, B, **tolerances)

Expand Down Expand Up @@ -287,6 +290,17 @@ def get_multiallelic_n16(self, seed=42):
assert ts.num_sites > 5
return ts

# Prepare a larger example dataset.
def get_large(self, n=50, length=1e5, mean_r=1e-5, mean_mu=1e-5, seed=42):
    """Simulate and return a tree sequence with ``n + 1`` samples.

    ``length``, ``mean_mu``, ``mean_r`` and ``seed`` are forwarded to
    ``msprime.simulate`` as the sequence length, mutation rate,
    recombination rate and random seed respectively.
    """
    simulation_options = {
        "length": length,
        "mutation_rate": mean_mu,
        "recombination_rate": mean_r,
        "random_seed": seed,
    }
    return msprime.simulate(n + 1, **simulation_options)


class FBAlgorithmBase(LSBase):
"""Base for testing forwards-backwards algorithms."""
Expand Down
224 changes: 16 additions & 208 deletions tests/test_LS_haploid_diploid.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import itertools
import pytest

import numpy as np
import numba as nb

import msprime
import tskit

import lshmm.core as core
import lsbase
import lshmm.fb_diploid as fbd
import lshmm.fb_haploid as fbh
import lshmm.vit_diploid as vd
Expand All @@ -17,68 +13,12 @@
class LSBase:
"""Superclass of Li and Stephens tests."""

def example_haplotypes(self, ts):
    """Split the genotype matrix of ``ts`` into a panel and example queries.

    The first column of the genotype matrix is taken as the base query and
    the remaining columns form the panel. Returns ``(panel, queries)`` where
    ``queries`` holds, in order: the base query, the last panel column, and
    three copies of the base query with ``core.MISSING`` written at the last
    position, at position ``ts.num_sites // 2``, and at every position.
    Each query has shape ``(1, number_of_rows)``.
    """
    full_matrix = ts.genotype_matrix()
    num_rows = full_matrix.shape[0]
    query = full_matrix[:, 0].reshape(1, num_rows)
    panel = full_matrix[:, 1:]

    queries = [query, panel[:, -1].reshape(1, num_rows)]
    # Positions to blank out: last site, middle site, then all sites.
    for masked in (-1, ts.num_sites // 2, slice(None)):
        with_missing = query.copy()
        with_missing[0, masked] = core.MISSING
        queries.append(with_missing)

    return panel, queries

def haplotype_emission(self, mu, m):
    """Build an ``(m, 2)`` float array with ``mu`` in column 0 and ``1 - mu`` in column 1.

    ``mu`` may be a scalar or anything that broadcasts to length ``m``.
    """
    columns = [np.broadcast_to(values, (m,)) for values in (mu, 1 - mu)]
    return np.stack(columns, axis=1).astype(np.float64)

def genotype_emission(self, mu, m):
    """Build an ``(m, 8)`` array of emission values derived from ``mu``.

    Columns are addressed by the index constants in ``core``; any column not
    listed below stays zero.
    """
    complement = 1 - mu
    # Assigned in the same order as the column list, one value per constant.
    column_values = (
        (core.EQUAL_BOTH_HOM, complement ** 2),
        (core.UNEQUAL_BOTH_HOM, mu**2),
        (core.BOTH_HET, complement),
        (core.REF_HOM_OBS_HET, 2 * mu * complement),
        (core.REF_HET_OBS_HOM, mu * complement),
        (core.MISSING_INDEX, 1),
    )
    emission = np.zeros((m, 8))
    for column, value in column_values:
        emission[:, column] = value
    return emission

def example_parameters_haplotypes(self, ts, seed=42):
    """Yield ``(n, m, H, s, e, r)`` tuples over example queries and parameters.

    ``H`` and the query haplotypes ``s`` come from ``self.example_haplotypes``;
    ``n`` is the number of columns of ``H`` and ``m`` the number of sites.
    First every query is paired with constant recombination and mutation
    probabilities of 0.01, then with every combination of three
    recombination vectors and three mutation vectors (extreme, tiny, random).
    ``r[0]`` is always 0 and ``e`` is built by ``self.haplotype_emission``.

    ``seed`` seeds NumPy's global random state so the random vectors are
    reproducible.
    """
    np.random.seed(seed)
    H, haplotypes = self.example_haplotypes(ts)
    n = H.shape[1]
    m = ts.num_sites

    # Here we have equal mutation and recombination
    r = np.full(m, 0.01)
    r[0] = 0
    e = self.haplotype_emission(np.full(m, 0.01), m)

    for s in haplotypes:
        yield n, m, H, s, e, r

    # Mixture of random and extremes
    rs = [np.full(m, 0.999), np.full(m, 1e-6), np.random.rand(m)]
    mus = [np.full(m, 0.33), np.full(m, 1e-6), np.random.rand(m) * 0.33]
    # Zero the first recombination probability once, up front, instead of
    # re-mutating the shared arrays on every iteration of the product.
    for r in rs:
        r[0] = 0

    for s, r, mu in itertools.product(haplotypes, rs, mus):
        yield n, m, H, s, self.haplotype_emission(mu, m), r

def example_parameters_haplotypes_larger(
self, ts, seed=42, mean_r=1e-5, mean_mu=1e-5
):
Expand All @@ -97,60 +37,6 @@ def example_parameters_haplotypes_larger(
for s in haplotypes:
yield n, m, H, s, e, r

def example_genotypes(self, ts, seed=42):
    """Build a reference panel, its genotype sums and example query genotypes.

    The first two columns of the genotype matrix of ``ts`` are summed to
    form the base query; the remaining columns form the panel ``H``.
    Returns ``(H, G, genotypes)`` where ``G[i, j, k] = H[i, j] + H[i, k]``
    and ``genotypes`` holds, in order: the base query, the sum of the last
    two panel columns, and three copies of the base query with
    ``core.MISSING`` written at the last position, at position
    ``ts.num_sites // 2``, and at every position.

    ``seed`` is accepted but not used by this method.
    """
    full_matrix = ts.genotype_matrix()
    num_rows = full_matrix.shape[0]

    def as_row(column):
        return column.reshape(1, num_rows)

    query = as_row(full_matrix[:, 0]) + as_row(full_matrix[:, 1])
    H = full_matrix[:, 2:]

    genotypes = [query, as_row(H[:, -1]) + as_row(H[:, -2])]
    # Positions to blank out: last site, middle site, then all sites.
    for masked in (-1, ts.num_sites // 2, slice(None)):
        with_missing = query.copy()
        with_missing[0, masked] = core.MISSING
        genotypes.append(with_missing)

    m = ts.get_num_sites()
    n = H.shape[1]

    # Pairwise allele sums per site, computed for all sites in one broadcast.
    G = np.zeros((m, n, n))
    G[:, :, :] = H[:, :, np.newaxis] + H[:, np.newaxis, :]

    return H, G, genotypes

def example_parameters_genotypes(self, ts, seed=42):
    """Yield ``(n, m, G, s, e, r)`` tuples over example genotypes and parameters.

    ``G`` and the query genotypes ``s`` come from ``self.example_genotypes``.
    Every query is first paired with constant recombination and mutation
    probabilities of 0.01, then with each combination of three recombination
    vectors and three mutation vectors. ``r[0]`` is set to 0 before each
    yield and ``e`` is built by ``self.genotype_emission``.

    ``seed`` seeds NumPy's global random state.
    """
    np.random.seed(seed)
    H, G, genotypes = self.example_genotypes(ts)
    n = H.shape[1]
    m = ts.get_num_sites()

    # Equal mutation and recombination probabilities.
    uniform_r = np.full(m, 0.01)
    uniform_mu = np.full(m, 0.01)
    uniform_r[0] = 0
    uniform_e = self.genotype_emission(uniform_mu, m)

    yield from ((n, m, G, query, uniform_e, uniform_r) for query in genotypes)

    # Extremes first, then a random vector, for each parameter.
    recombination_choices = [
        np.full(m, 0.999),
        np.full(m, 1e-6),
        np.random.rand(m),
    ]
    mutation_choices = [
        np.full(m, 0.33),
        np.full(m, 1e-6),
        np.random.rand(m) * 0.33,
    ]

    combinations = itertools.product(
        genotypes, recombination_choices, mutation_choices
    )
    for query, r, mu in combinations:
        r[0] = 0
        yield n, m, G, query, self.genotype_emission(mu, m), r

def example_parameters_genotypes_larger(
self, ts, seed=42, mean_r=1e-5, mean_mu=1e-5
):
Expand All @@ -170,63 +56,8 @@ def example_parameters_genotypes_larger(
for s in genotypes:
yield n, m, G, s, e, r

def assertAllClose(self, A, B):
    """Raise AssertionError unless A and B are element-wise close.

    Closeness is decided by ``np.allclose`` with ``rtol=1e-09`` and
    ``atol=1e-08``. The error is raised explicitly rather than through a
    bare ``assert`` so the check survives ``python -O`` and the failure
    message shows the offending values.
    """
    if not np.allclose(A, B, rtol=1e-09, atol=1e-08):
        raise AssertionError(
            f"Values differ beyond rtol=1e-09, atol=1e-08:\n{A!r}\n{B!r}"
        )

# Define a bunch of very small tree-sequences for testing a collection of parameters on
def test_simple_n_10_no_recombination(self):
    """Run ``verify`` on a 10-sample simulation with zero recombination."""
    simulation_options = {
        "recombination_rate": 0,
        "mutation_rate": 0.5,
        "random_seed": 42,
    }
    tree_sequence = msprime.simulate(10, **simulation_options)
    assert tree_sequence.num_sites > 3
    self.verify(tree_sequence)

def test_simple_n_6(self):
    """Run ``verify`` on a 6-sample simulation."""
    simulation_options = {
        "recombination_rate": 2,
        "mutation_rate": 7,
        "random_seed": 42,
    }
    tree_sequence = msprime.simulate(6, **simulation_options)
    assert tree_sequence.num_sites > 5
    self.verify(tree_sequence)

def test_simple_n_8(self):
    """Run ``verify`` on an 8-sample simulation."""
    simulation_options = {
        "recombination_rate": 2,
        "mutation_rate": 5,
        "random_seed": 42,
    }
    tree_sequence = msprime.simulate(8, **simulation_options)
    assert tree_sequence.num_sites > 5
    self.verify(tree_sequence)

def test_simple_n_8_high_recombination(self):
    """Run ``verify`` on an 8-sample simulation with a high recombination rate."""
    simulation_options = {
        "recombination_rate": 20,
        "mutation_rate": 5,
        "random_seed": 42,
    }
    tree_sequence = msprime.simulate(8, **simulation_options)
    # The example must contain many trees as well as enough sites.
    assert tree_sequence.num_trees > 15
    assert tree_sequence.num_sites > 5
    self.verify(tree_sequence)

def test_simple_n_16(self):
    """Run ``verify`` on a 16-sample simulation."""
    simulation_options = {
        "recombination_rate": 2,
        "mutation_rate": 5,
        "random_seed": 42,
    }
    tree_sequence = msprime.simulate(16, **simulation_options)
    assert tree_sequence.num_sites > 5
    self.verify(tree_sequence)

# Test a bigger one.
def test_large(self, n=50, length=100000, mean_r=1e-5, mean_mu=1e-5, seed=42):
def test_large(self, n=50, length=1e5, mean_r=1e-5, mean_mu=1e-5, seed=42):
ts = msprime.simulate(
n + 1,
length=length,
Expand All @@ -236,26 +67,10 @@ def test_large(self, n=50, length=100000, mean_r=1e-5, mean_mu=1e-5, seed=42):
)
self.verify_larger(ts)

def verify(self, ts):
    """Check ``ts``; always raises here, so subclasses must override it."""
    raise NotImplementedError()

def verify_larger(self, ts):
    """Optional hook taking ``ts``; a no-op unless a subclass overrides it."""
    pass


class FBAlgorithmBase(LSBase):
    """Base class for forwards-backwards algorithm tests."""


class TestNonTreeMethodsHap(FBAlgorithmBase):
"""Test that the computed likelihoods are the same across all implementations."""

class TestNonTreeMethodsHap(lsbase.FBAlgorithmBase):
def verify(self, ts):
for n, m, H_vs, s, e_vs, r in self.example_parameters_haplotypes(ts):
e_sv = e_vs.T
H_sv = H_vs.T

# variants x samples
for n, m, H_vs, s, e_vs, r in self.get_examples_pars_haploid(ts):
F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, norm=False)
B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r)
self.assertAllClose(np.log10(np.sum(F_vs * B_vs, 1)), ll_vs * np.ones(m))
Expand All @@ -269,9 +84,6 @@ def verify(self, ts):
def verify_larger(self, ts):
# variants x samples
for n, m, H_vs, s, e_vs, r in self.example_parameters_haplotypes_larger(ts):
e_sv = e_vs.T
H_sv = H_vs.T

F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, norm=False)
B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r)
self.assertAllClose(np.log10(np.sum(F_vs * B_vs, 1)), ll_vs * np.ones(m))
Expand All @@ -283,14 +95,13 @@ def verify_larger(self, ts):
self.assertAllClose(np.sum(F_tmp * B_tmp, 1), np.ones(m))


class TestNonTreeMethodsDip(FBAlgorithmBase):
"""Test that the computed likelihoods are the same across all implementations."""

class TestNonTreeMethodsDip(lsbase.FBAlgorithmBase):
def verify(self, ts):
for n, m, G_vs, s, e_vs, r in self.example_parameters_genotypes(ts):
for n, m, G_vs, s, e_vs, r in self.get_examples_pars_diploid(ts):
F_vs, c_vs, ll_vs = fbd.forwards_ls_dip(n, m, G_vs, s, e_vs, r, norm=True)
B_vs = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_vs, r)
self.assertAllClose(np.sum(F_vs * B_vs, (1, 2)), np.ones(m))

F_tmp, c_tmp, ll_tmp = fbd.forwards_ls_dip(
n, m, G_vs, s, e_vs, r, norm=False
)
Expand Down Expand Up @@ -366,47 +177,46 @@ def verify_larger(self, ts):
self.assertAllClose(np.sum(F_tmp * B_tmp, (1, 2)), np.ones(m))


class VitAlgorithmBase(LSBase):
    """Base class for Viterbi algorithm tests."""


class TestNonTreeViterbiHap(VitAlgorithmBase):
"""Test that the computed log-likelihoods are the same across all implementations."""

class TestNonTreeViterbiHap(lsbase.ViterbiAlgorithmBase):
def verify(self, ts):
for n, m, H_vs, s, e_vs, r in self.example_parameters_haplotypes(ts):
for n, m, H_vs, s, e_vs, r in self.get_examples_pars_haploid(ts):
V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(n, m, H_vs, s, e_vs, r)
path_vs = vh.backwards_viterbi_hap(m, V_vs[m - 1, :], P_vs)
ll_check = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r)
self.assertAllClose(ll_vs, ll_check)

V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_vec(
n, m, H_vs, s, e_vs, r
)
path_tmp = vh.backwards_viterbi_hap(m, V_tmp[m - 1, :], P_tmp)
ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem(
n, m, H_vs, s, e_vs, r
)
path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp)
ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem_rescaling(
n, m, H_vs, s, e_vs, r
)
path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp)
ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_low_mem_rescaling(
n, m, H_vs, s, e_vs, r
)
path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp)
ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_lower_mem_rescaling(
n, m, H_vs, s, e_vs, r
)
Expand Down Expand Up @@ -490,11 +300,9 @@ def verify_larger(self, ts):
self.assertAllClose(ll_vs, ll_tmp)


class TestNonTreeViterbiDip(VitAlgorithmBase):
"""Test that the computed log-likelihoods are the same across all implementations."""

class TestNonTreeViterbiDip(lsbase.ViterbiAlgorithmBase):
def verify(self, ts):
for n, m, G_vs, s, e_vs, r in self.example_parameters_genotypes(ts):
for n, m, G_vs, s, e_vs, r in self.get_examples_pars_diploid(ts):
V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_naive(n, m, G_vs, s, e_vs, r)
path_vs = vd.backwards_viterbi_dip(m, V_vs[m - 1, :, :], P_vs)
phased_path_vs = vd.get_phased_path(n, path_vs)
Expand Down

0 comments on commit ec8ed22

Please sign in to comment.