Skip to content

Commit

Permalink
Major refactor
Browse files: browse the repository at this point in the history
  • Loading branch information
szhan committed Apr 19, 2024
1 parent f1b99b3 commit 114efbb
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 188 deletions.
12 changes: 6 additions & 6 deletions tests/lsbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def get_emission_prob_matrix_diploid(self, mu, m):

def get_examples_parameters_diploid(self, ts, seed=42):
np.random.seed(seed)
H, G, genotypes = self.example_genotypes(ts)
H, G, genotypes = self.get_examples_diploid(ts)
n = H.shape[1]
m = ts.get_num_sites()

Expand All @@ -151,7 +151,7 @@ def get_examples_parameters_diploid(self, ts, seed=42):
mu = np.zeros(m) + 0.01
r[0] = 0

e = self.genotype_emission(mu, m)
e = self.get_emission_prob_matrix_diploid(mu, m)

for s in genotypes:
yield n, m, G, s, e, r, mu
Expand All @@ -160,18 +160,18 @@ def get_examples_parameters_diploid(self, ts, seed=42):
rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)]
mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33]

e = self.genotype_emission(mu, m)
e = self.get_emission_prob_matrix_diploid(mu, m)

for s, r, mu in itertools.product(genotypes, rs, mus):
r[0] = 0
e = self.genotype_emission(mu, m)
e = self.get_emission_prob_matrix_diploid(mu, m)
yield n, m, G, s, e, r, mu

def get_examples_parameters_larger_diploid(
self, ts, mean_r=1e-5, mean_mu=1e-5, seed=42
):
np.random.seed(seed)
H, G, genotypes = self.example_genotypes(ts)
H, G, genotypes = self.get_examples_diploid(ts)

m = ts.get_num_sites()
n = H.shape[1]
Expand All @@ -181,7 +181,7 @@ def get_examples_parameters_larger_diploid(

mu = mean_mu * np.ones(m) * ((np.random.rand(m) + 0.5) / 2)

e = self.genotype_emission(mu, m)
e = self.get_emission_prob_matrix_diploid(mu, m)

for s in genotypes:
yield n, m, G, s, e, r, mu
Expand Down
48 changes: 48 additions & 0 deletions tests/test_API.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,22 @@ def verify(self, ts):


class TestMethodsDiploid(lsbase.FBAlgorithmBase):
def test_simple_n_10_no_recombination(self):
    """Run ``verify`` on the tree sequence from ``get_simple_n_10_no_recombination``."""
    self.verify(self.get_simple_n_10_no_recombination())

def test_simple_n_6(self):
    """Run ``verify`` on the tree sequence from ``get_simple_n_6``."""
    self.verify(self.get_simple_n_6())

def test_simple_n_8(self):
    """Run ``verify`` on the tree sequence from ``get_simple_n_8``."""
    self.verify(self.get_simple_n_8())

def test_simple_n_16(self):
    """Run ``verify`` on the tree sequence from ``get_simple_n_16``."""
    self.verify(self.get_simple_n_16())

def verify(self, ts):
for n, m, G_vs, s, e_vs, r, mu in self.get_examples_parameters_diploid(ts):
F_vs, c_vs, ll_vs = fbd.forward_ls_dip_loop(
Expand All @@ -63,6 +79,22 @@ def verify(self, ts):


class TestViterbiHaploid(lsbase.ViterbiAlgorithmBase):
def test_simple_n_10_no_recombination(self):
    """Run ``verify`` on the tree sequence from ``get_simple_n_10_no_recombination``."""
    self.verify(self.get_simple_n_10_no_recombination())

def test_simple_n_6(self):
    """Run ``verify`` on the tree sequence from ``get_simple_n_6``."""
    self.verify(self.get_simple_n_6())

def test_simple_n_8(self):
    """Run ``verify`` on the tree sequence from ``get_simple_n_8``."""
    self.verify(self.get_simple_n_8())

def test_simple_n_16(self):
    """Run ``verify`` on the tree sequence from ``get_simple_n_16``."""
    self.verify(self.get_simple_n_16())

def verify(self, ts):
for n, m, H_vs, s, e_vs, r, mu in self.get_examples_parameters_haploid(ts):
V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling(
Expand All @@ -76,6 +108,22 @@ def verify(self, ts):


class TestViterbiDiploid(lsbase.ViterbiAlgorithmBase):
def test_simple_n_10_no_recombination(self):
    """Run ``verify`` on the tree sequence from ``get_simple_n_10_no_recombination``."""
    self.verify(self.get_simple_n_10_no_recombination())

def test_simple_n_6(self):
    """Run ``verify`` on the tree sequence from ``get_simple_n_6``."""
    self.verify(self.get_simple_n_6())

def test_simple_n_8(self):
    """Run ``verify`` on the tree sequence from ``get_simple_n_8``."""
    self.verify(self.get_simple_n_8())

def test_simple_n_16(self):
    """Run ``verify`` on the tree sequence from ``get_simple_n_16``."""
    self.verify(self.get_simple_n_16())

def verify(self, ts):
for n, m, G_vs, s, e_vs, r, mu in self.get_examples_parameters_diploid(ts):
V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem(n, m, G_vs, s, e_vs, r)
Expand Down
188 changes: 6 additions & 182 deletions tests/test_API_multiallelic.py
Original file line number Diff line number Diff line change
@@ -1,184 +1,14 @@
import itertools
import pytest

import numpy as np

import msprime
import tskit

from . import lsbase
import lshmm as ls
import lshmm.core as core
import lshmm.fb_diploid as fbd
import lshmm.fb_haploid as fbh
import lshmm.vit_diploid as vd
import lshmm.vit_haploid as vh


class LSBase:
"""Superclass of Li and Stephens tests."""

def example_haplotypes(self, ts, seed=42):
    """Split the genotype matrix of ``ts`` into a reference panel and queries.

    The first column of ``ts.genotype_matrix()`` is taken as the base query
    and removed from the panel. Five queries are returned, in this order:

    1. the base query,
    2. the last column of the remaining panel,
    3. the base query with its last entry set to ``core.MISSING``,
    4. the base query with entry ``ts.num_sites // 2`` set to ``core.MISSING``,
    5. the base query with every entry set to ``core.MISSING``.

    Each query has shape ``(1, number_of_matrix_rows)``.

    ``seed`` is accepted but not used by this method.

    Returns:
        A ``(panel, queries)`` tuple.
    """
    genotypes = ts.genotype_matrix()
    num_rows = genotypes.shape[0]
    query = genotypes[:, 0].reshape(1, num_rows)
    panel = genotypes[:, 1:]

    def masked(index):
        # Work on a copy so the base query (and the matrix it views) is untouched.
        masked_query = query.copy()
        masked_query[0, index] = core.MISSING
        return masked_query

    queries = [query, panel[:, -1].reshape(1, num_rows)]
    queries.extend(
        masked(index) for index in (-1, ts.num_sites // 2, slice(None))
    )
    return panel, queries

def haplotype_emission(self, mu, m, n_alleles, scale_mutation_based_on_n_alleles):
    """Build the ``(m, 2)`` emission probability matrix.

    Args:
        mu: Per-site mutation probability. Either a scalar of any real
            numeric type (applied to every site) or an array of length ``m``.
        m: Number of sites (rows of the result).
        n_alleles: Array-like of length ``m`` with the number of distinct
            alleles at each site.
        scale_mutation_based_on_n_alleles: Selects the parameterisation.
            If true, column 0 is ``mu`` and column 1 is
            ``1 - (n_alleles - 1) * mu``; otherwise column 0 is
            ``mu / (n_alleles - 1)`` and column 1 is ``1 - mu``.

    Returns:
        A float array ``e`` of shape ``(m, 2)``. At sites where
        ``n_alleles == 1`` column 0 is ``0`` and column 1 is ``1`` in both
        parameterisations.
        NOTE(review): columns presumably hold mismatch / match
        probabilities; confirm against the consumers of ``e``.
    """
    # Accept any scalar (int, float, NumPy scalar) as well as per-site arrays;
    # the previous isinstance(mu, float) check missed non-float scalars.
    mu = np.broadcast_to(np.asarray(mu, dtype=float), (m,))
    n_alleles = np.asarray(n_alleles)
    invariant = n_alleles == 1

    e = np.zeros((m, 2))
    if scale_mutation_based_on_n_alleles:
        e[:, 0] = np.where(invariant, 0.0, mu)
        e[:, 1] = 1 - (n_alleles - 1) * mu
    else:
        # Substitute a harmless divisor at invariant sites so the division
        # never hits zero; those entries are overwritten by np.where anyway.
        divisor = np.where(invariant, 1, n_alleles - 1)
        e[:, 0] = np.where(invariant, 0.0, mu / divisor)
        e[:, 1] = np.where(invariant, 1.0, 1 - mu)
    return e

def example_parameters_haplotypes(self, ts, seed=42, scale_mutation=True):
    """Yield ``(n, m, H, s, e, r, mu)`` test cases built from ``ts``.

    ``H`` and the query haplotypes ``s`` come from ``example_haplotypes``;
    ``n`` is ``H.shape[1]`` and ``m`` is ``ts.get_num_sites()``. The emission
    matrix ``e`` is recomputed for every case via ``haplotype_emission``.

    Two groups of cases are produced:

    1. every query with ``r = mu = 0.01`` at all sites;
    2. the product of every query with three recombination arrays
       (0.999, 1e-6, random) and three mutation arrays (0.2, 1e-6,
       random scaled by 0.2).

    In both groups ``r[0]`` is set to 0. The ``r`` arrays are shared between
    yielded cases and modified in place, so consumers must not mutate them.

    Args:
        ts: Tree sequence supplying the genotype data.
        seed: Seed passed to ``np.random.seed`` before any random draw.
        scale_mutation: Forwarded to ``haplotype_emission`` as
            ``scale_mutation_based_on_n_alleles``.
    """
    np.random.seed(seed)
    H, haplotypes = self.example_haplotypes(ts)
    n = H.shape[1]
    m = ts.get_num_sites()

    def _get_num_alleles(ref_haps, query):
        assert ref_haps.shape[0] == query.shape[1]
        num_sites = ref_haps.shape[0]
        num_alleles = np.zeros(num_sites, dtype=np.int8)
        for i in range(num_sites):
            uniq_alleles = np.unique(np.append(ref_haps[i, :], query[:, i]))
            # MISSING is a placeholder, not an allele, so it is not counted.
            num_alleles[i] = np.sum(uniq_alleles != core.MISSING)
        # Strictly positive: the old ">= 0" check could never fail and
        # contradicted its own message.
        assert np.all(num_alleles > 0), "Number of alleles cannot be zero."
        return num_alleles

    def _emission(s, mu):
        # Must be calculated from the genotype matrix because the query can
        # carry alleles that are absent from the reference panel.
        n_alleles = _get_num_alleles(H, s)
        return self.haplotype_emission(
            mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation
        )

    # Here we have equal mutation and recombination
    r = np.zeros(m) + 0.01
    mu = np.zeros(m) + 0.01
    r[0] = 0

    for s in haplotypes:
        yield n, m, H, s, _emission(s, mu), r, mu

    # Mixture of random and extremes
    rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)]
    mus = [np.zeros(m) + 0.2, np.zeros(m) + 1e-6, np.random.rand(m) * 0.2]

    for s, r, mu in itertools.product(haplotypes, rs, mus):
        r[0] = 0
        yield n, m, H, s, _emission(s, mu), r, mu

def assertAllClose(self, A, B):
    """Assert ``A`` and ``B`` agree element-wise to a relative tolerance of 1e-9.

    The absolute tolerance is zero, so the comparison is purely relative.
    """
    matches = np.allclose(A, B, rtol=1e-9, atol=0.0)
    assert matches

# Define a bunch of very small tree-sequences for testing a collection of parameters on
def test_simple_n_10_no_recombination(self):
    """Verify on a 10-sample simulation with zero recombination rate.

    Requires the simulated tree sequence to have more than 3 sites.
    """
    ancestry = msprime.sim_ancestry(
        samples=10,
        recombination_rate=0,
        sequence_length=10,
        population_size=10000,
        random_seed=42,
    )
    mutated = msprime.sim_mutations(ancestry, rate=1e-5, random_seed=42)
    assert mutated.num_sites > 3
    self.verify(mutated)

def test_simple_n_6(self):
    """Verify on a 6-sample simulation with recombination rate 1e-4.

    Requires the simulated tree sequence to have more than 5 sites.
    """
    ancestry = msprime.sim_ancestry(
        samples=6,
        recombination_rate=1e-4,
        random_seed=42,
        sequence_length=40,
        population_size=10000,
    )
    mutated = msprime.sim_mutations(ancestry, rate=1e-3, random_seed=42)
    assert mutated.num_sites > 5
    self.verify(mutated)

def test_simple_n_8(self):
    """Verify on an 8-sample simulation with recombination rate 1e-4.

    Requires the simulated tree sequence to have more than 5 sites and
    more than 15 trees.
    """
    ancestry = msprime.sim_ancestry(
        samples=8,
        recombination_rate=1e-4,
        sequence_length=20,
        population_size=10000,
        random_seed=42,
    )
    mutated = msprime.sim_mutations(
        ancestry,
        rate=1e-4,
        random_seed=42,
    )
    assert mutated.num_sites > 5
    assert mutated.num_trees > 15
    self.verify(mutated)

def test_simple_n_16(self):
    """Verify on a 16-sample simulation with recombination rate 1e-2.

    Requires the simulated tree sequence to have more than 5 sites.
    """
    ancestry = msprime.sim_ancestry(
        samples=16,
        recombination_rate=1e-2,
        sequence_length=20,
        population_size=10000,
        random_seed=42,
    )
    mutated = msprime.sim_mutations(ancestry, rate=1e-4, random_seed=42)
    assert mutated.num_sites > 5
    self.verify(mutated)

class TestMethodsHaploid(lsbase.FBAlgorithmBase):
def verify(self, ts):
raise NotImplementedError()


class FBAlgorithmBase(LSBase):
    """Base class for forwards-backwards algorithm tests.

    Declares no members of its own; everything is inherited from ``LSBase``.
    """


class TestMethodsHap(FBAlgorithmBase):
"""Test that the computed likelihood is the same across all implementations."""

def verify(self, ts):
for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts):
for n, m, H_vs, s, e_vs, r, mu in self.get_examples_parameters_haploid(ts):
F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r)
B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r)
F, c, ll = ls.forwards(H_vs, s, r, p_mutation=mu)
Expand All @@ -187,7 +17,7 @@ def verify(self, ts):
self.assertAllClose(B, B_vs)
self.assertAllClose(ll_vs, ll)

for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(
for n, m, H_vs, s, e_vs, r, mu in self.get_examples_parameters_haploid(
ts, scale_mutation=False
):
F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r)
Expand All @@ -203,15 +33,9 @@ def verify(self, ts):
self.assertAllClose(ll_vs, ll)


class VitAlgorithmBase(LSBase):
    """Base class for Viterbi algorithm tests.

    Declares no members of its own; everything is inherited from ``LSBase``.
    """


class TestViterbiHap(VitAlgorithmBase):
"""Test that the computed log-likelihoods are the same across all implementations."""

class TestViterbiHaploid(lsbase.ViterbiAlgorithmBase):
def verify(self, ts):
for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts):
for n, m, H_vs, s, e_vs, r, mu in self.get_examples_parameters_haploid(ts):
V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling(
n, m, H_vs, s, e_vs, r
)
Expand Down

0 comments on commit 114efbb

Please sign in to comment.