diff --git a/lshmm/vit_diploid.py b/lshmm/vit_diploid.py index 83312fa..d4ef207 100644 --- a/lshmm/vit_diploid.py +++ b/lshmm/vit_diploid.py @@ -315,7 +315,7 @@ def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r): def forwards_viterbi_dip_naive_full_vec(n, m, G, s, e, r): - """Fully vectorised naive LS diploid Viterbi algorithm using numpy.""" + """Fully vectorised naive implementation using Numpy.""" char_both = np.eye(n * n).ravel().reshape((n, n, n, n)) char_col = np.tile(np.sum(np.eye(n * n).reshape((n, n, n, n)), 3), (n, 1, 1, 1)) char_row = np.copy(char_col).T diff --git a/tests/lsbase.py b/tests/lsbase.py new file mode 100644 index 0000000..a5bfc39 --- /dev/null +++ b/tests/lsbase.py @@ -0,0 +1,319 @@ +import itertools + +import numpy as np + +import msprime + +import lshmm.core as core + + +class LSBase: + """Base class of Li and Stephens tests.""" + + def get_num_alleles(self, ref_haps, query): + assert ref_haps.shape[0] == query.shape[1] + num_sites = ref_haps.shape[0] + num_alleles = np.zeros(num_sites, dtype=np.int8) + exclusion_set = np.array([core.MISSING]) + for i in range(num_sites): + uniq_alleles = np.unique(np.append(ref_haps[i, :], query[:, i])) + num_alleles[i] = np.sum(~np.isin(uniq_alleles, exclusion_set)) + assert np.all(num_alleles >= 0), "Number of alleles cannot be zero." + return num_alleles + + # Haploid + def get_examples_haploid(self, ts): + H = ts.genotype_matrix() + s = H[:, 0].reshape(1, H.shape[0]) + H = H[:, 1:] + + haplotypes = [s, H[:, -1].reshape(1, H.shape[0])] + s_tmp = s.copy() + s_tmp[0, -1] = core.MISSING + haplotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, ts.num_sites // 2] = core.MISSING + haplotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, :] = core.MISSING + haplotypes.append(s_tmp) + + return H, haplotypes + + def get_emission_prob_matrix_haploid( + self, mu, m, n_alleles, scale_mutation_based_on_n_alleles + ): + e = np.zeros((m, 2)) + if isinstance(mu, float): + mu = mu * np.ones(m) + + if scale_mutation_based_on_n_alleles: + e[:, 0] = mu - mu * np.equal( + n_alleles, np.ones(m) + ) # Add boolean in case we're at an invariant site + e[:, 1] = 1 - (n_alleles - 1) * mu + else: + for j in range(m): + if n_alleles[j] == 1: + # In case we're at an invariant site + e[j, 0] = 0 + e[j, 1] = 1 + else: + e[j, 0] = mu[j] / (n_alleles[j] - 1) + e[j, 1] = 1 - mu[j] + return e + + def get_examples_parameters_haploid(self, ts, scale_mutation=True, seed=42): + """ + Returns an iterator over combinations of haplotypes, recombination probabilties, + and mutation probabilities. + """ + np.random.seed(seed) + H, haplotypes = self.get_examples_haploid(ts) + n = H.shape[1] + m = ts.get_num_sites() + + # Here we have equal mutation and recombination + r = np.zeros(m) + 0.01 + mu = np.zeros(m) + 0.01 + r[0] = 0 + + for s in haplotypes: + # Must be calculated from the genotype matrix because we can now get back mutations that + # result in the number of alleles being higher than the number of alleles in the reference panel. + n_alleles = self.get_num_alleles(H, s) + e = self.get_emission_prob_matrix_haploid( + mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation + ) + yield n, m, H, s, e, r, mu + + # Mixture of random and extremes + rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] + mus = [np.zeros(m) + 0.2, np.zeros(m) + 1e-6, np.random.rand(m) * 0.2] + + for s, r, mu in itertools.product(haplotypes, rs, mus): + r[0] = 0 + n_alleles = self.get_num_alleles(H, s) + e = self.get_emission_prob_matrix_haploid( + mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation + ) + yield n, m, H, s, e, r, mu + + # Diploid + def get_examples_diploid(self, ts, seed=42): + np.random.seed(seed) + H = ts.genotype_matrix() + s = H[:, 0].reshape(1, H.shape[0]) + H[:, 1].reshape(1, H.shape[0]) + H = H[:, 2:] + + genotypes = [ + s, + H[:, -1].reshape(1, H.shape[0]) + H[:, -2].reshape(1, H.shape[0]), + ] + + s_tmp = s.copy() + s_tmp[0, -1] = core.MISSING + genotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, ts.num_sites // 2] = core.MISSING + genotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, :] = core.MISSING + genotypes.append(s_tmp) + + m = ts.get_num_sites() + n = H.shape[1] + + G = np.zeros((m, n, n)) + for i in range(m): + G[i, :, :] = np.add.outer(H[i, :], H[i, :]) + + return H, G, genotypes + + def get_emission_prob_matrix_diploid(self, mu, m): + e = np.zeros((m, 8)) + e[:, core.EQUAL_BOTH_HOM] = (1 - mu) ** 2 + e[:, core.UNEQUAL_BOTH_HOM] = mu**2 + e[:, core.BOTH_HET] = (1 - mu) ** 2 + mu**2 + e[:, core.REF_HOM_OBS_HET] = 2 * mu * (1 - mu) + e[:, core.REF_HET_OBS_HOM] = mu * (1 - mu) + e[:, core.MISSING_INDEX] = 1 + return e + + def get_examples_parameters_diploid(self, ts, seed=42): + np.random.seed(seed) + H, G, genotypes = self.example_genotypes(ts) + n = H.shape[1] + m = ts.get_num_sites() + + # Here we have equal mutation and recombination + r = np.zeros(m) + 0.01 + mu = np.zeros(m) + 0.01 + r[0] = 0 + + e = self.genotype_emission(mu, m) + + for s in genotypes: + yield n, m, G, s, e, r, mu + + # Mixture of random and extremes + rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] + mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33] + + e = self.genotype_emission(mu, m) + + for s, r, mu in itertools.product(genotypes, rs, mus): + r[0] = 0 + e = self.genotype_emission(mu, m) + yield n, m, G, s, e, r, mu + + def get_examples_parameters_larger_diploid( + self, ts, mean_r=1e-5, mean_mu=1e-5, seed=42 + ): + np.random.seed(seed) + H, G, genotypes = self.example_genotypes(ts) + + m = ts.get_num_sites() + n = H.shape[1] + + r = mean_r * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) + r[0] = 0 + + mu = mean_mu * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) + + e = self.genotype_emission(mu, m) + + for s in genotypes: + yield n, m, G, s, e, r, mu + + # Prepare simple example datasets. + def get_simple_n_10_no_recombination(self, seed=42): + ts = msprime.simulate( + 10, + recombination_rate=0, + mutation_rate=0.5, + random_seed=seed, + ) + assert ts.num_sites > 3 + return ts + + def get_simple_n_6(self, seed=42): + ts = msprime.simulate( + 6, + recombination_rate=2, + mutation_rate=7, + random_seed=seed, + ) + assert ts.num_sites > 5 + return ts + + def get_simple_n_8(self, seed=42): + ts = msprime.simulate( + 8, + recombination_rate=2, + mutation_rate=5, + random_seed=seed, + ) + assert ts.num_sites > 5 + return ts + + def get_simple_n_8_high_recombination(self, seed=42): + ts = msprime.simulate( + 8, + recombination_rate=20, + mutation_rate=5, + random_seed=seed, + ) + assert ts.num_trees > 15 + assert ts.num_sites > 5 + return ts + + def get_simple_n_16(self, seed=42): + ts = msprime.simulate( + 16, + recombination_rate=2, + mutation_rate=5, + random_seed=seed, + ) + assert ts.num_sites > 5 + return ts + + # Prepare example datasets with multiallelic sites. + def get_multiallelic_n_10_no_recombination(self, seed=42): + ts = msprime.sim_ancestry( + samples=10, + recombination_rate=0, + sequence_length=10, + population_size=1e4, + random_seed=seed, + ) + ts = msprime.sim_mutations( + ts, + rate=1e-5, + random_seed=seed, + ) + assert ts.num_sites > 3 + return ts + + def get_multiallelic_n_6(self, seed=42): + ts = msprime.sim_ancestry( + samples=6, + recombination_rate=1e-4, + sequence_length=40, + population_size=1e4, + random_seed=seed, + ) + ts = msprime.sim_mutations( + ts, + rate=1e-3, + random_seed=seed, + ) + assert ts.num_sites > 5 + return ts + + def get_multiallelic_n_8(self, seed=42): + ts = msprime.sim_ancestry( + samples=8, + recombination_rate=1e-4, + sequence_length=20, + population_size=1e4, + random_seed=seed, + ) + ts = msprime.sim_mutations( + ts, + rate=1e-4, + random_seed=seed, + ) + assert ts.num_sites > 5 + assert ts.num_trees > 15 + return ts + + def get_multiallelic_n_16(self, seed=42): + ts = msprime.sim_ancestry( + samples=16, + recombination_rate=1e-2, + sequence_length=20, + population_size=1e4, + random_seed=seed, + ) + ts = msprime.sim_mutations( + ts, + rate=1e-4, + random_seed=seed, + ) + assert ts.num_sites > 5 + return ts + + def verify(self, ts): + raise NotImplementedError() + + def assertAllClose(self, A, B): + assert np.allclose(A, B, rtol=1e-9, atol=0.0) + + +class FBAlgorithmBase(LSBase): + """Base for testing forwards-backwards algorithms.""" + + +class ViterbiAlgorithmBase(LSBase): + """Base for testing Viterbi algoritms.""" diff --git a/tests/test_API.py b/tests/test_API.py index 0df87ce..677f75c 100644 --- a/tests/test_API.py +++ b/tests/test_API.py @@ -6,6 +6,7 @@ import msprime import tskit +from . import lsbase import lshmm as ls import lshmm.core as core import lshmm.fb_diploid as fbd @@ -14,247 +15,27 @@ import lshmm.vit_haploid as vh -class LSBase: - """Superclass of Li and Stephens tests.""" - - def example_haplotypes(self, ts, seed=42): - H = ts.genotype_matrix() - s = H[:, 0].reshape(1, H.shape[0]) - H = H[:, 1:] - - haplotypes = [s, H[:, -1].reshape(1, H.shape[0])] - s_tmp = s.copy() - s_tmp[0, -1] = core.MISSING - haplotypes.append(s_tmp) - s_tmp = s.copy() - s_tmp[0, ts.num_sites // 2] = core.MISSING - haplotypes.append(s_tmp) - s_tmp = s.copy() - s_tmp[0, :] = core.MISSING - haplotypes.append(s_tmp) - - return H, haplotypes - - def haplotype_emission(self, mu, m, n_alleles, scale_mutation_based_on_n_alleles): - # Define the emission probability matrix - e = np.zeros((m, 2)) - if isinstance(mu, float): - mu = mu * np.ones(m) - - if scale_mutation_based_on_n_alleles: - e[:, 0] = mu - mu * np.equal( - n_alleles, np.ones(m) - ) # Added boolean in case we're at an invariant site - e[:, 1] = 1 - (n_alleles - 1) * mu - else: - for j in range(m): - if n_alleles[j] == 1: # In case we're at an invariant site - e[j, 0] = 0 - e[j, 1] = 1 - else: - e[j, 0] = mu[j] / (n_alleles[j] - 1) - e[j, 1] = 1 - mu[j] - return e - - def genotype_emission(self, mu, m): - # Define the emission probability matrix - e = np.zeros((m, 8)) - e[:, core.EQUAL_BOTH_HOM] = (1 - mu) ** 2 - e[:, core.UNEQUAL_BOTH_HOM] = mu**2 - e[:, core.BOTH_HET] = (1 - mu) ** 2 + mu**2 - e[:, core.REF_HOM_OBS_HET] = 2 * mu * (1 - mu) - e[:, core.REF_HET_OBS_HOM] = mu * (1 - mu) - e[:, core.MISSING_INDEX] = 1 - return e - - def example_parameters_haplotypes(self, ts, seed=42, scale_mutation=True): - """ - Returns an iterator over combinations of haplotype, recombination and mutation probabilities. - """ - np.random.seed(seed) - H, haplotypes = self.example_haplotypes(ts) - n = H.shape[1] - m = ts.get_num_sites() - - def _get_num_alleles(ref_haps, query): - assert ref_haps.shape[0] == query.shape[1] - num_sites = ref_haps.shape[0] - num_alleles = np.zeros(num_sites, dtype=np.int8) - exclusion_set = np.array([core.MISSING]) - for i in range(num_sites): - uniq_alleles = np.unique(np.append(ref_haps[i, :], query[:, i])) - num_alleles[i] = np.sum(~np.isin(uniq_alleles, exclusion_set)) - assert np.all(num_alleles >= 0), "Number of alleles cannot be zero." - return num_alleles - - # Here we have equal mutation and recombination - r = np.zeros(m) + 0.01 - mu = np.zeros(m) + 0.01 - r[0] = 0 - - for s in haplotypes: - n_alleles = _get_num_alleles(H, s) - e = self.haplotype_emission( - mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation - ) - yield n, m, H, s, e, r, mu - - # Mixture of random and extremes - rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] - mus = [np.zeros(m) + 0.2, np.zeros(m) + 1e-6, np.random.rand(m) * 0.2] - - for s, r, mu in itertools.product(haplotypes, rs, mus): - r[0] = 0 - n_alleles = _get_num_alleles(H, s) - e = self.haplotype_emission( - mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation - ) - yield n, m, H, s, e, r, mu - - def example_genotypes(self, ts, seed=42): - np.random.seed(seed) - H = ts.genotype_matrix() - s = H[:, 0].reshape(1, H.shape[0]) + H[:, 1].reshape(1, H.shape[0]) - H = H[:, 2:] - - genotypes = [ - s, - H[:, -1].reshape(1, H.shape[0]) + H[:, -2].reshape(1, H.shape[0]), - ] - - s_tmp = s.copy() - s_tmp[0, -1] = core.MISSING - genotypes.append(s_tmp) - s_tmp = s.copy() - s_tmp[0, ts.num_sites // 2] = core.MISSING - genotypes.append(s_tmp) - s_tmp = s.copy() - s_tmp[0, :] = core.MISSING - genotypes.append(s_tmp) - - m = ts.get_num_sites() - n = H.shape[1] - - G = np.zeros((m, n, n)) - for i in range(m): - G[i, :, :] = np.add.outer(H[i, :], H[i, :]) - - return H, G, genotypes - - def example_parameters_genotypes(self, ts, seed=42): - np.random.seed(seed) - H, G, genotypes = self.example_genotypes(ts) - n = H.shape[1] - m = ts.get_num_sites() - - # Here we have equal mutation and recombination - r = np.zeros(m) + 0.01 - mu = np.zeros(m) + 0.01 - r[0] = 0 - - e = self.genotype_emission(mu, m) - - for s in genotypes: - yield n, m, G, s, e, r, mu - - # Mixture of random and extremes - rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] - mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33] - - e = self.genotype_emission(mu, m) - - for s, r, mu in itertools.product(genotypes, rs, mus): - r[0] = 0 - e = self.genotype_emission(mu, m) - yield n, m, G, s, e, r, mu - - def example_parameters_genotypes_larger( - self, ts, seed=42, mean_r=1e-5, mean_mu=1e-5 - ): - np.random.seed(seed) - H, G, genotypes = self.example_genotypes(ts) - - m = ts.get_num_sites() - n = H.shape[1] - - r = mean_r * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) - r[0] = 0 - - mu = mean_mu * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) - - e = self.genotype_emission(mu, m) - - for s in genotypes: - yield n, m, G, s, e, r, mu - - def assertAllClose(self, A, B): - assert np.allclose(A, B, rtol=1e-9, atol=0.0) +class TestMethodsHaploid(lsbase.FBAlgorithmBase): + """Test that the computed likelihood is the same across all implementations.""" - # Define a bunch of very small tree-sequences for testing a collection of parameters on def test_simple_n_10_no_recombination(self): - ts = msprime.simulate( - samples=10, - recombination_rate=0, - mutation_rate=0.5, - random_seed=42, - ) - assert ts.num_sites > 3 + ts = self.get_simple_n_10_no_recombination() self.verify(ts) def test_simple_n_6(self): - ts = msprime.simulate( - samples=6, - recombination_rate=2, - mutation_rate=7, - random_seed=42, - ) - assert ts.num_sites > 5 + ts = self.get_simple_n_6() self.verify(ts) def test_simple_n_8(self): - ts = msprime.simulate( - samples=8, - recombination_rate=2, - mutation_rate=5, - random_seed=42, - ) - assert ts.num_sites > 5 - self.verify(ts) - - def test_simple_n_8_high_recombination(self): - ts = msprime.simulate( - samples=8, - recombination_rate=20, - mutation_rate=5, - random_seed=42, - ) - assert ts.num_trees > 15 - assert ts.num_sites > 5 + ts = self.get_simple_n_8() self.verify(ts) def test_simple_n_16(self): - ts = msprime.simulate( - samples=16, - recombination_rate=2, - mutation_rate=5, - random_seed=42, - ) - assert ts.num_sites > 5 + ts = self.get_simple_n_16() self.verify(ts) def verify(self, ts): - raise NotImplementedError() - - -class FBAlgorithmBase(LSBase): - """Base for forwards backwards algorithm tests.""" - - -class TestMethodsHap(FBAlgorithmBase): - """Test that the computed likelihood is the same across all implementations.""" - - def verify(self, ts): - for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts): + for n, m, H_vs, s, e_vs, r, mu in self.get_examples_parameters_haploid(ts): F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r) B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r) F, c, ll = ls.forwards(H_vs, s, r, p_mutation=mu) @@ -267,11 +48,9 @@ def verify(self, ts): B = ls.backwards(H_vs, s, c, r, mu) -class TestMethodsDip(FBAlgorithmBase): - """Test that the computed likelihood is the same across all implementations.""" - +class TestMethodsDiploid(lsbase.FBAlgorithmBase): def verify(self, ts): - for n, m, G_vs, s, e_vs, r, mu in self.example_parameters_genotypes(ts): + for n, m, G_vs, s, e_vs, r, mu in self.get_examples_parameters_diploid(ts): F_vs, c_vs, ll_vs = fbd.forward_ls_dip_loop( n, m, G_vs, s, e_vs, r, norm=True ) @@ -283,15 +62,9 @@ def verify(self, ts): self.assertAllClose(ll_vs, ll) -class VitAlgorithmBase(LSBase): - """Base for Viterbi algoritm tests.""" - - -class TestViterbiHap(VitAlgorithmBase): - """Test that the computed log-likelihood is the same across all implementations.""" - +class TestViterbiHaploid(lsbase.ViterbiAlgorithmBase): def verify(self, ts): - for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts): + for n, m, H_vs, s, e_vs, r, mu in self.get_examples_parameters_haploid(ts): V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling( n, m, H_vs, s, e_vs, r ) @@ -302,11 +75,9 @@ def verify(self, ts): self.assertAllClose(path_vs, path) -class TestViterbiDip(VitAlgorithmBase): - """Test that the computed log-likelihood is the same across all implementations.""" - +class TestViterbiDiploid(lsbase.ViterbiAlgorithmBase): def verify(self, ts): - for n, m, G_vs, s, e_vs, r, mu in self.example_parameters_genotypes(ts): + for n, m, G_vs, s, e_vs, r, mu in self.get_examples_parameters_diploid(ts): V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem(n, m, G_vs, s, e_vs, r) path_vs = vd.backwards_viterbi_dip(m, V_vs, P_vs) phased_path_vs = vd.get_phased_path(n, path_vs) diff --git a/tests/test_API_multiallelic.py b/tests/test_API_multiallelic.py index 5c01c3d..5c69853 100644 --- a/tests/test_API_multiallelic.py +++ b/tests/test_API_multiallelic.py @@ -204,11 +204,11 @@ def verify(self, ts): class VitAlgorithmBase(LSBase): - """Base for viterbi algoritm tests.""" + """Base for Viterbi algorithm tests.""" class TestViterbiHap(VitAlgorithmBase): - """Test that we have the same log-likelihood across all implementations.""" + """Test that the computed log-likelihoods are the same across all implementations.""" def verify(self, ts): for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts):