diff --git a/lshmm/api.py b/lshmm/api.py index 35a461c..e394cb2 100644 --- a/lshmm/api.py +++ b/lshmm/api.py @@ -140,7 +140,10 @@ def set_emission_probabilities( ) else: emission_probs = core.get_emission_matrix_diploid( - mu=prob_mutation, num_sites=num_sites + mu=prob_mutation, + num_sites=num_sites, + num_alleles=num_alleles, + scale_mutation_rate=scale_mutation_rate, ) return emission_probs diff --git a/lshmm/core.py b/lshmm/core.py index 1a722b1..aa958f7 100644 --- a/lshmm/core.py +++ b/lshmm/core.py @@ -224,15 +224,35 @@ def get_emission_matrix_haploid(mu, num_sites, num_alleles, scale_mutation_rate) return e -def get_emission_matrix_diploid(mu, num_sites): - e = np.zeros((num_sites, 8)) - e[:, EQUAL_BOTH_HOM] = (1 - mu) ** 2 - e[:, UNEQUAL_BOTH_HOM] = mu**2 - e[:, BOTH_HET] = (1 - mu) ** 2 + mu**2 - e[:, REF_HOM_OBS_HET] = 2 * mu * (1 - mu) - e[:, REF_HET_OBS_HOM] = mu * (1 - mu) - e[:, MISSING_INDEX] = 1 - return e +def get_emission_matrix_diploid(mu, num_sites, num_alleles, scale_mutation_rate): + assert len(mu) == len( + num_alleles + ), "Arrays of mutation probability and number of alleles are unequal in length." + if isinstance(mu, float): + mu = np.zeros(num_sites, dtype=np.float64) + mu + prob_mutation = np.zeros(num_sites, dtype=np.float64) - np.inf + prob_no_mutation = np.zeros(num_sites, dtype=np.float64) - np.inf + emission_matrix = np.zeros((num_sites, 8), dtype=np.float64) - np.inf + for i in range(num_sites): + if num_alleles[i] == 1: + # Set probabilities at invariant sites. + prob_mutation[i] = 0 + prob_no_mutation[i] = 1 + else: + if scale_mutation_rate: + prob_mutation[i] = mu[i] + prob_no_mutation[i] = 1 - (num_alleles[i] - 1) * mu[i] + else: + prob_mutation[i] = mu[i] / (num_alleles[i] - 1) + prob_no_mutation[i] = 1 - mu[i] + for i in range(num_sites): + emission_matrix[i, EQUAL_BOTH_HOM] = prob_no_mutation[i] ** 2 + emission_matrix[i, UNEQUAL_BOTH_HOM] = prob_mutation[i] ** 2 + emission_matrix[i, BOTH_HET] = prob_no_mutation[i] ** 2 + prob_mutation[i] ** 2 + emission_matrix[i, REF_HOM_OBS_HET] = 2 * prob_mutation[i] * prob_no_mutation[i] + emission_matrix[i, REF_HET_OBS_HOM] = prob_mutation[i] * prob_no_mutation[i] + emission_matrix[i, MISSING_INDEX] = 1 + return emission_matrix @jit.numba_njit diff --git a/tests/lsbase.py b/tests/lsbase.py index c7c8ddf..ebab957 100644 --- a/tests/lsbase.py +++ b/tests/lsbase.py @@ -147,11 +147,19 @@ def get_examples_pars( num_alleles = core.get_num_alleles(H, query) if ploidy == 1: e = core.get_emission_matrix_haploid( - mu, m, num_alleles, scale_mutation_rate + mu=mu, + num_sites=m, + num_alleles=num_alleles, + scale_mutation_rate=scale_mutation_rate, ) yield n, m, H, query, e, r, mu else: - e = core.get_emission_matrix_diploid(mu, m) + e = core.get_emission_matrix_diploid( + mu=mu, + num_sites=m, + num_alleles=num_alleles, + scale_mutation_rate=scale_mutation_rate, + ) yield n, m, H, query, e, r, mu # Prepare simple example datasets. diff --git a/tests/test_API.py b/tests/test_API.py index d5fe793..64f341c 100644 --- a/tests/test_API.py +++ b/tests/test_API.py @@ -107,6 +107,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs) s = core.convert_haplotypes_to_unphased_genotypes(query) num_alleles = core.get_num_alleles(H_vs, query) + F_vs, c_vs, ll_vs = fbd.forward_ls_dip_loop( n=n, m=m, @@ -148,39 +149,65 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): self.assertAllClose(B, B_vs) self.assertAllClose(ll_vs, ll) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) - def test_ts_simple_n10_no_recomb(self, include_ancestors): + def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n10_no_recomb() - self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) - def test_ts_simple_n6(self, include_ancestors): + def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n6() - self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) - def test_ts_simple_n8(self, include_ancestors): + def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n8() - self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) - def test_ts_simple_n8_high_recomb(self, include_ancestors): + def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n8_high_recomb() - self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) - def test_ts_simple_n16(self, include_ancestors): + def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n16() - self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) - def test_ts_larger(self, include_ancestors): + def test_ts_larger(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_custom_pars( ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5 ) self.verify( ts, - scale_mutation_rate=True, + scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, ) @@ -264,6 +291,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs) s = core.convert_haplotypes_to_unphased_genotypes(query) num_alleles = core.get_num_alleles(H_vs, query) + V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem( n=n, m=m, @@ -282,41 +310,68 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): prob_mutation=mu, scale_mutation_rate=scale_mutation_rate, ) + self.assertAllClose(ll_vs, ll) self.assertAllClose(phased_path_vs, path) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) - def test_ts_simple_n10_no_recomb(self, include_ancestors): + def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n10_no_recomb() - self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) - def test_ts_simple_n6(self, include_ancestors): + def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n6() - self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) - def test_ts_simple_n8(self, include_ancestors): + def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n8() - self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) - def test_ts_simple_n8_high_recomb(self, include_ancestors): + def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n8_high_recomb() - self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) - def test_ts_simple_n16(self, include_ancestors): + def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n16() - self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) - def test_ts_larger(self, include_ancestors): + def test_ts_larger(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_custom_pars( ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5 ) self.verify( ts, - scale_mutation_rate=True, + scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, ) diff --git a/tests/test_non_tree.py b/tests/test_non_tree.py index f2f9d5a..8521558 100644 --- a/tests/test_non_tree.py +++ b/tests/test_non_tree.py @@ -128,76 +128,86 @@ def verify( ) self.assertAllClose(ll_vs, ll_tmp) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) @pytest.mark.parametrize("normalise", [True, False]) - def test_ts_simple_n10_no_recomb(self, include_ancestors, normalise): + def test_ts_simple_n10_no_recomb( + self, scale_mutation_rate, include_ancestors, normalise + ): ts = self.get_ts_simple_n10_no_recomb() # Test extreme rates only when normalising, # because they can lead to pathological cases. include_extreme_rates = normalise self.verify( ts, - scale_mutation_rate=True, + scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, normalise=normalise, include_extreme_rates=include_extreme_rates, ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) @pytest.mark.parametrize("normalise", [True, False]) - def test_ts_simple_n6(self, include_ancestors, normalise): + def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors, normalise): ts = self.get_ts_simple_n6() include_extreme_rates = normalise self.verify( ts, - scale_mutation_rate=True, + scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, normalise=normalise, include_extreme_rates=include_extreme_rates, ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) @pytest.mark.parametrize("normalise", [True, False]) - def test_ts_simple_n8(self, include_ancestors, normalise): + def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors, normalise): ts = self.get_ts_simple_n8() include_extreme_rates = normalise self.verify( ts, - scale_mutation_rate=True, + scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, normalise=normalise, include_extreme_rates=include_extreme_rates, ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) @pytest.mark.parametrize("normalise", [True, False]) - def test_ts_simple_n8_high_recomb(self, include_ancestors, normalise): + def test_ts_simple_n8_high_recomb( + self, scale_mutation_rate, include_ancestors, normalise + ): ts = self.get_ts_simple_n8_high_recomb() include_extreme_rates = normalise self.verify( ts, - scale_mutation_rate=True, + scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, normalise=normalise, include_extreme_rates=include_extreme_rates, ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) @pytest.mark.parametrize("normalise", [True, False]) - def test_ts_simple_n16(self, include_ancestors, normalise): + def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors, normalise): ts = self.get_ts_simple_n16() include_extreme_rates = normalise self.verify( ts, - scale_mutation_rate=True, + scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, normalise=normalise, include_extreme_rates=include_extreme_rates, ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) @pytest.mark.parametrize("normalise", [True, False]) - def test_ts_larger(self, include_ancestors, normalise): + def test_ts_larger(self, scale_mutation_rate, include_ancestors, normalise): ts = self.get_ts_custom_pars( ref_panel_size=45, length=1e5, @@ -207,7 +217,7 @@ def test_ts_larger(self, include_ancestors, normalise): include_extreme_rates = normalise self.verify( ts, - scale_mutation_rate=True, + scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, normalise=normalise, include_extreme_rates=include_extreme_rates, @@ -389,38 +399,64 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): self.assertAllClose(ll_tmp, path_ll_tmp) self.assertAllClose(ll_vs, ll_tmp) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) - def test_ts_simple_n10_no_recomb(self, include_ancestors): + def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n10_no_recomb() - self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) - def test_ts_simple_n6(self, include_ancestors): + def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n6() - self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) - def test_ts_simple_n8(self, include_ancestors): + def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n8() - self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) - def test_ts_simple_n8_high_recomb(self, include_ancestors): + def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n8_high_recomb() - self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) - def test_ts_simple_n16(self, include_ancestors): + def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n16() - self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) + @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [False]) - def test_ts_larger(self, include_ancestors): + def test_ts_larger(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_custom_pars( ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5 ) self.verify( ts, - scale_mutation_rate=True, + scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, )