From 114efbb5cb9a92c8901e94bcdd627ccd61502db0 Mon Sep 17 00:00:00 2001 From: szhan Date: Fri, 19 Apr 2024 16:29:41 +0100 Subject: [PATCH] Major refactor --- tests/lsbase.py | 12 +-- tests/test_API.py | 48 +++++++++ tests/test_API_multiallelic.py | 188 ++------------------------------- 3 files changed, 60 insertions(+), 188 deletions(-) diff --git a/tests/lsbase.py b/tests/lsbase.py index a5bfc39..f4194cb 100644 --- a/tests/lsbase.py +++ b/tests/lsbase.py @@ -142,7 +142,7 @@ def get_emission_prob_matrix_diploid(self, mu, m): def get_examples_parameters_diploid(self, ts, seed=42): np.random.seed(seed) - H, G, genotypes = self.example_genotypes(ts) + H, G, genotypes = self.get_examples_diploid(ts) n = H.shape[1] m = ts.get_num_sites() @@ -151,7 +151,7 @@ def get_examples_parameters_diploid(self, ts, seed=42): mu = np.zeros(m) + 0.01 r[0] = 0 - e = self.genotype_emission(mu, m) + e = self.get_emission_prob_matrix_diploid(mu, m) for s in genotypes: yield n, m, G, s, e, r, mu @@ -160,18 +160,18 @@ def get_examples_parameters_diploid(self, ts, seed=42): rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33] - e = self.genotype_emission(mu, m) + e = self.get_emission_prob_matrix_diploid(mu, m) for s, r, mu in itertools.product(genotypes, rs, mus): r[0] = 0 - e = self.genotype_emission(mu, m) + e = self.get_emission_prob_matrix_diploid(mu, m) yield n, m, G, s, e, r, mu def get_examples_parameters_larger_diploid( self, ts, mean_r=1e-5, mean_mu=1e-5, seed=42 ): np.random.seed(seed) - H, G, genotypes = self.example_genotypes(ts) + H, G, genotypes = self.get_examples_diploid(ts) m = ts.get_num_sites() n = H.shape[1] @@ -181,7 +181,7 @@ def get_examples_parameters_larger_diploid( mu = mean_mu * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) - e = self.genotype_emission(mu, m) + e = self.get_emission_prob_matrix_diploid(mu, m) for s in genotypes: yield n, m, G, s, e, r, mu diff --git a/tests/test_API.py b/tests/test_API.py index 677f75c..1854e43 100644 --- a/tests/test_API.py +++ b/tests/test_API.py @@ -49,6 +49,22 @@ def verify(self, ts): class TestMethodsDiploid(lsbase.FBAlgorithmBase): + def test_simple_n_10_no_recombination(self): + ts = self.get_simple_n_10_no_recombination() + self.verify(ts) + + def test_simple_n_6(self): + ts = self.get_simple_n_6() + self.verify(ts) + + def test_simple_n_8(self): + ts = self.get_simple_n_8() + self.verify(ts) + + def test_simple_n_16(self): + ts = self.get_simple_n_16() + self.verify(ts) + def verify(self, ts): for n, m, G_vs, s, e_vs, r, mu in self.get_examples_parameters_diploid(ts): F_vs, c_vs, ll_vs = fbd.forward_ls_dip_loop( @@ -63,6 +79,22 @@ def verify(self, ts): class TestViterbiHaploid(lsbase.ViterbiAlgorithmBase): + def test_simple_n_10_no_recombination(self): + ts = self.get_simple_n_10_no_recombination() + self.verify(ts) + + def test_simple_n_6(self): + ts = self.get_simple_n_6() + self.verify(ts) + + def test_simple_n_8(self): + ts = self.get_simple_n_8() + self.verify(ts) + + def test_simple_n_16(self): + ts = self.get_simple_n_16() + self.verify(ts) + def verify(self, ts): for n, m, H_vs, s, e_vs, r, mu in self.get_examples_parameters_haploid(ts): V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling( @@ -76,6 +108,22 @@ def verify(self, ts): class TestViterbiDiploid(lsbase.ViterbiAlgorithmBase): + def test_simple_n_10_no_recombination(self): + ts = self.get_simple_n_10_no_recombination() + self.verify(ts) + + def test_simple_n_6(self): + ts = self.get_simple_n_6() + self.verify(ts) + + def test_simple_n_8(self): + ts = self.get_simple_n_8() + self.verify(ts) + + def test_simple_n_16(self): + ts = self.get_simple_n_16() + self.verify(ts) + def verify(self, ts): for n, m, G_vs, s, e_vs, r, mu in self.get_examples_parameters_diploid(ts): V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem(n, m, G_vs, s, e_vs, r) diff --git a/tests/test_API_multiallelic.py b/tests/test_API_multiallelic.py index 5c69853..b034d18 100644 --- a/tests/test_API_multiallelic.py +++ b/tests/test_API_multiallelic.py @@ -1,184 +1,14 @@ -import itertools -import pytest - -import numpy as np - -import msprime -import tskit - +from . import lsbase import lshmm as ls -import lshmm.core as core import lshmm.fb_diploid as fbd import lshmm.fb_haploid as fbh import lshmm.vit_diploid as vd import lshmm.vit_haploid as vh -class LSBase: - """Superclass of Li and Stephens tests.""" - - def example_haplotypes(self, ts, seed=42): - H = ts.genotype_matrix() - s = H[:, 0].reshape(1, H.shape[0]) - H = H[:, 1:] - - haplotypes = [s, H[:, -1].reshape(1, H.shape[0])] - s_tmp = s.copy() - s_tmp[0, -1] = core.MISSING - haplotypes.append(s_tmp) - s_tmp = s.copy() - s_tmp[0, ts.num_sites // 2] = core.MISSING - haplotypes.append(s_tmp) - s_tmp = s.copy() - s_tmp[0, :] = core.MISSING - haplotypes.append(s_tmp) - - return H, haplotypes - - def haplotype_emission(self, mu, m, n_alleles, scale_mutation_based_on_n_alleles): - # Define the emission probability matrix - e = np.zeros((m, 2)) - if isinstance(mu, float): - mu = mu * np.ones(m) - - if scale_mutation_based_on_n_alleles: - e[:, 0] = mu - mu * np.equal( - n_alleles, np.ones(m) - ) # Added boolean in case we're at an invariant site - e[:, 1] = 1 - (n_alleles - 1) * mu - else: - for j in range(m): - if n_alleles[j] == 1: # In case we're at an invariant site - e[j, 0] = 0 - e[j, 1] = 1 - else: - e[j, 0] = mu[j] / (n_alleles[j] - 1) - e[j, 1] = 1 - mu[j] - return e - - def example_parameters_haplotypes(self, ts, seed=42, scale_mutation=True): - """Returns an iterator over combinations of haplotype, recombination and - mutation probabilities.""" - np.random.seed(seed) - H, haplotypes = self.example_haplotypes(ts) - n = H.shape[1] - m = ts.get_num_sites() - - def _get_num_alleles(ref_haps, query): - assert ref_haps.shape[0] == query.shape[1] - num_sites = ref_haps.shape[0] - num_alleles = np.zeros(num_sites, dtype=np.int8) - exclusion_set = np.array([core.MISSING]) - for i in range(num_sites): - uniq_alleles = np.unique(np.append(ref_haps[i, :], query[:, i])) - num_alleles[i] = np.sum(~np.isin(uniq_alleles, exclusion_set)) - assert np.all(num_alleles >= 0), "Number of alleles cannot be zero." - return num_alleles - - # Here we have equal mutation and recombination - r = np.zeros(m) + 0.01 - mu = np.zeros(m) + 0.01 - r[0] = 0 - - for s in haplotypes: - # Must be calculated from the genotype matrix because we can now get back mutations that - # result in the number of alleles being higher than the number of alleles in the reference panel. - n_alleles = _get_num_alleles(H, s) - e = self.haplotype_emission( - mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation - ) - yield n, m, H, s, e, r, mu - - # Mixture of random and extremes - rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] - mus = [np.zeros(m) + 0.2, np.zeros(m) + 1e-6, np.random.rand(m) * 0.2] - - for s, r, mu in itertools.product(haplotypes, rs, mus): - r[0] = 0 - n_alleles = _get_num_alleles(H, s) - e = self.haplotype_emission( - mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation - ) - yield n, m, H, s, e, r, mu - - def assertAllClose(self, A, B): - assert np.allclose(A, B, rtol=1e-9, atol=0.0) - - # Define a bunch of very small tree-sequences for testing a collection of parameters on - def test_simple_n_10_no_recombination(self): - ts = msprime.sim_ancestry( - samples=10, - recombination_rate=0, - sequence_length=10, - population_size=10000, - random_seed=42, - ) - ts = msprime.sim_mutations( - ts, - rate=1e-5, - random_seed=42, - ) - assert ts.num_sites > 3 - self.verify(ts) - - def test_simple_n_6(self): - ts = msprime.sim_ancestry( - samples=6, - recombination_rate=1e-4, - random_seed=42, - sequence_length=40, - population_size=10000, - ) - ts = msprime.sim_mutations( - ts, - rate=1e-3, - random_seed=42, - ) - assert ts.num_sites > 5 - self.verify(ts) - - def test_simple_n_8(self): - ts = msprime.sim_ancestry( - samples=8, - recombination_rate=1e-4, - sequence_length=20, - population_size=10000, - random_seed=42, - ) - ts = msprime.sim_mutations(ts, rate=1e-4, random_seed=42) - assert ts.num_sites > 5 - assert ts.num_trees > 15 - self.verify(ts) - - def test_simple_n_16(self): - ts = msprime.sim_ancestry( - samples=16, - recombination_rate=1e-2, - sequence_length=20, - population_size=10000, - random_seed=42, - ) - ts = msprime.sim_mutations( - ts, - rate=1e-4, - random_seed=42, - ) - assert ts.num_sites > 5 - self.verify(ts) - +class TestMethodsHaploid(lsbase.FBAlgorithmBase): def verify(self, ts): - raise NotImplementedError() - - -class FBAlgorithmBase(LSBase): - """Base for forwards backwards algorithm tests.""" - - -class TestMethodsHap(FBAlgorithmBase): - """Test that the computed likelihood is the same across all implementations.""" - - def verify(self, ts): - for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts): + for n, m, H_vs, s, e_vs, r, mu in self.get_examples_parameters_haploid(ts): F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r) B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r) F, c, ll = ls.forwards(H_vs, s, r, p_mutation=mu) @@ -187,7 +17,7 @@ def verify(self, ts): self.assertAllClose(B, B_vs) self.assertAllClose(ll_vs, ll) - for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes( + for n, m, H_vs, s, e_vs, r, mu in self.get_examples_parameters_haploid( ts, scale_mutation=False ): F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r) @@ -203,15 +33,9 @@ def verify(self, ts): self.assertAllClose(ll_vs, ll) -class VitAlgorithmBase(LSBase): - """Base for Viterbi algorithm tests.""" - - -class TestViterbiHap(VitAlgorithmBase): - """Test that the computed log-likelihoods are the same across all implementations.""" - +class TestViterbiHaploid(lsbase.ViterbiAlgorithmBase): def verify(self, ts): - for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts): + for n, m, H_vs, s, e_vs, r, mu in self.get_examples_parameters_haploid(ts): V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling( n, m, H_vs, s, e_vs, r )