From 063cd1d80a5b37d622d030a5700137d9e41e20ea Mon Sep 17 00:00:00 2001 From: szhan Date: Sat, 1 Jun 2024 13:05:55 +0100 Subject: [PATCH] Implement NONCOPY for diploid case --- lshmm/core.py | 7 ++-- tests/lsbase.py | 8 ++-- tests/test_API.py | 24 ++++++------ tests/test_non_tree.py | 85 ++++++++++++++++++++++-------------------- 4 files changed, 64 insertions(+), 60 deletions(-) diff --git a/lshmm/core.py b/lshmm/core.py index aa958f7..ae0fbfc 100644 --- a/lshmm/core.py +++ b/lshmm/core.py @@ -64,7 +64,6 @@ def convert_haplotypes_to_phased_genotypes(ref_panel): The only allowable allele states are 0, 1, and NONCOPY (for partial ancestral haplotypes). TODO: Handle multiallelic sites. - TODO: Handle NONCOPY. Allowable genotype values are 0, 1, 2, and NONCOPY. If either one haplotype is NONCOPY at a site, then the genotype at the site is assigned NONCOPY. @@ -75,7 +74,7 @@ def convert_haplotypes_to_phased_genotypes(ref_panel): :return: An array of reference genotypes. :rtype: numpy.ndarray """ - ALLOWED_ALLELE_STATES = np.array([0, 1], dtype=np.int32) + ALLOWED_ALLELE_STATES = np.array([0, 1, NONCOPY], dtype=np.int32) assert np.all( np.isin(np.unique(ref_panel), ALLOWED_ALLELE_STATES) ), f"Reference haplotypes contain illegal allele states." @@ -85,8 +84,8 @@ def convert_haplotypes_to_phased_genotypes(ref_panel): for i in range(num_sites): site_alleles = ref_panel[i, :] genotypes[i, :, :] = np.add.outer(site_alleles, site_alleles) - # genotypes[i, site_alleles == NONCOPY, :] = NONCOPY - # genotypes[i, :, site_alleles == NONCOPY] = NONCOPY + genotypes[i, site_alleles == NONCOPY, :] = NONCOPY + genotypes[i, :, site_alleles == NONCOPY] = NONCOPY return genotypes diff --git a/tests/lsbase.py b/tests/lsbase.py index ebab957..0d94570 100644 --- a/tests/lsbase.py +++ b/tests/lsbase.py @@ -71,9 +71,6 @@ def get_examples_haploid(self, ts, include_ancestors): return ref_panel, queries def get_examples_diploid(self, ts, include_ancestors): - # TODO Handle NONCOPY properly. - if include_ancestors is True: - raise NotImplementedError ref_panel = ts.genotype_matrix() num_sites = ref_panel.shape[0] # Take some haplotypes as queries from the reference panel. @@ -92,7 +89,10 @@ def get_examples_diploid(self, ts, include_ancestors): query_miss_most[:, 1:] = core.MISSING queries = [query_1, query_2, query_miss_last, query_miss_mid, query_miss_most] # Exclude the arbitrarily chosen queries from the reference panel. - ref_panel = ref_panel[:, 2:-2] + if include_ancestors: + ref_panel = self.get_ancestral_haplotypes(ts) + else: + ref_panel = ref_panel[:, 2:-2] return ref_panel, queries def get_examples_pars( diff --git a/tests/test_API.py b/tests/test_API.py index 80ad98c..016085e 100644 --- a/tests/test_API.py +++ b/tests/test_API.py @@ -176,7 +176,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): self.assertAllClose(ll_vs, ll) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n10_no_recomb() self.verify( @@ -186,7 +186,7 @@ def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors): ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n6() self.verify( @@ -196,7 +196,7 @@ def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors): ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n8() self.verify( @@ -206,7 +206,7 @@ def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors): ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n8_high_recomb() self.verify( @@ -216,7 +216,7 @@ def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors): ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n16() self.verify( @@ -226,7 +226,7 @@ def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors): ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_larger(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_custom_pars( ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5 @@ -367,7 +367,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): self.assertAllClose(phased_path_vs, path) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n10_no_recomb() self.verify( @@ -377,7 +377,7 @@ def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors): ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n6() self.verify( @@ -387,7 +387,7 @@ def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors): ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n8() self.verify( @@ -397,7 +397,7 @@ def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors): ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n8_high_recomb() self.verify( @@ -407,7 +407,7 @@ def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors): ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n16() self.verify( @@ -417,7 +417,7 @@ def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors): ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_larger(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_custom_pars( ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5 diff --git a/tests/test_non_tree.py b/tests/test_non_tree.py index 4a5ebcf..6f653a1 100644 --- a/tests/test_non_tree.py +++ b/tests/test_non_tree.py @@ -113,9 +113,11 @@ def verify( ): G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs) s = core.convert_haplotypes_to_unphased_genotypes(query) + F_vs, c_vs, ll_vs = fbd.forwards_ls_dip(n, m, G_vs, s, e_vs, r, norm=True) B_vs = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_vs, r) self.assertAllClose(np.sum(F_vs * B_vs, (1, 2)), np.ones(m)) + F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop( n, m, G_vs, s, e_vs, r, norm=True ) @@ -155,7 +157,7 @@ def verify( self.assertAllClose(ll_vs, ll_tmp) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) @pytest.mark.parametrize("normalise", [True, False]) def test_ts_simple_n10_no_recomb( self, scale_mutation_rate, include_ancestors, normalise @@ -173,7 +175,7 @@ def test_ts_simple_n10_no_recomb( ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) @pytest.mark.parametrize("normalise", [True, False]) def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors, normalise): ts = self.get_ts_simple_n6() @@ -187,7 +189,7 @@ def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors, normalise): ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) @pytest.mark.parametrize("normalise", [True, False]) def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors, normalise): ts = self.get_ts_simple_n8() @@ -201,7 +203,7 @@ def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors, normalise): ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) @pytest.mark.parametrize("normalise", [True, False]) def test_ts_simple_n8_high_recomb( self, scale_mutation_rate, include_ancestors, normalise @@ -217,7 +219,7 @@ def test_ts_simple_n8_high_recomb( ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) @pytest.mark.parametrize("normalise", [True, False]) def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors, normalise): ts = self.get_ts_simple_n16() @@ -231,7 +233,7 @@ def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors, normalise): ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) @pytest.mark.parametrize("normalise", [True, False]) def test_ts_larger(self, scale_mutation_rate, include_ancestors, normalise): ts = self.get_ts_custom_pars( @@ -395,30 +397,13 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): ): G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs) s = core.convert_haplotypes_to_unphased_genotypes(query) - V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_naive(n, m, G_vs, s, e_vs, r) - path_vs = vd.backwards_viterbi_dip(m, V_vs[m - 1, :, :], P_vs) + + V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem(n, m, G_vs, s, e_vs, r) + path_vs = vd.backwards_viterbi_dip(m, V_vs, P_vs) phased_path_vs = vd.get_phased_path(n, path_vs) path_ll_vs = vd.path_ll_dip(n, m, G_vs, phased_path_vs, s, e_vs, r) self.assertAllClose(ll_vs, path_ll_vs) - V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_low_mem( - n, m, G_vs, s, e_vs, r - ) - path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) - - V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_low_mem( - n, m, G_vs, s, e_vs, r - ) - path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) - ( V_tmp, V_argmaxes_tmp, @@ -442,17 +427,37 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): self.assertAllClose(ll_tmp, path_ll_tmp) self.assertAllClose(ll_vs, ll_tmp) - V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_vec( - n, m, G_vs, s, e_vs, r - ) - path_tmp = vd.backwards_viterbi_dip(m, V_tmp[m - 1, :, :], P_tmp) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) + num_ref_haps = H_vs.shape[1] + if num_ref_haps <= 100: + V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive( + n, m, G_vs, s, e_vs, r + ) + path_tmp = vd.backwards_viterbi_dip(m, V_tmp[m - 1, :, :], P_tmp) + phased_path_tmp = vd.get_phased_path(n, path_tmp) + path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, path_ll_tmp) + self.assertAllClose(ll_vs, ll_tmp) + + V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_low_mem( + n, m, G_vs, s, e_vs, r + ) + path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) + phased_path_tmp = vd.get_phased_path(n, path_tmp) + path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, path_ll_tmp) + self.assertAllClose(ll_vs, ll_tmp) + + V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_vec( + n, m, G_vs, s, e_vs, r + ) + path_tmp = vd.backwards_viterbi_dip(m, V_tmp[m - 1, :, :], P_tmp) + phased_path_tmp = vd.get_phased_path(n, path_tmp) + path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, path_ll_tmp) + self.assertAllClose(ll_vs, ll_tmp) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n10_no_recomb() self.verify( @@ -462,7 +467,7 @@ def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors): ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n6() self.verify( @@ -472,7 +477,7 @@ def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors): ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n8() self.verify( @@ -482,7 +487,7 @@ def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors): ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n8_high_recomb() self.verify( @@ -492,7 +497,7 @@ def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors): ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n16() self.verify( @@ -502,7 +507,7 @@ def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors): ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [False]) + @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_larger(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_custom_pars( ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5