From 73ae3767f2d6a291cd63411694cc6ed491d219ad Mon Sep 17 00:00:00 2001 From: szhan Date: Fri, 21 Jun 2024 13:24:22 +0100 Subject: [PATCH] Change API functions to take query sequences of unphased genotypes --- lshmm/api.py | 56 ++++++++++++++++++++++--------- tests/lsbase.py | 24 +++++++++---- tests/test_api_fb_diploid.py | 11 +++--- tests/test_api_vit_diploid.py | 5 ++- tests/test_nontree_fb_diploid.py | 27 ++++++++------- tests/test_nontree_vit_diploid.py | 31 ++++++++++------- 6 files changed, 100 insertions(+), 54 deletions(-) diff --git a/lshmm/api.py b/lshmm/api.py index ef76468..f94b146 100644 --- a/lshmm/api.py +++ b/lshmm/api.py @@ -41,16 +41,21 @@ def check_inputs( The reference panel and query are arrays of size (m, n) and (k, m), respectively, where: m = number of sites. - n = number of samples in the reference panel (haplotypes, not individuals). - k = number of samples in the query (haplotypes, not individuals). + n = number of haplotypes (not individuals) in the reference panel. + k = number of haploid or diploid individuals in the query. - TODO: Support running on multiple queries. Currently, only k = 1 or 2 is supported. + In the haploid case, queries are (phased) haplotypes and can have multiallelic sites. + + In the diploid case queries are unphased genotypes (encoded as allele dosages). + Currently, only biallelic sites are supported. + + TODO: Support running on multiple queries. Currently, only k = 1 is supported. The mutation rate can be scaled according to the set of alleles that can be mutated to based on the number of distinct alleles at each site. - :param numpy.ndarray reference_panel: A panel of reference sequences. - :param numpy.ndarray query: A query sequence. + :param numpy.ndarray reference_panel: A panel of reference haplotypes. + :param numpy.ndarray query: A query (a haplotype or a sequence of allelic dosages). :param numpy.ndarray ploidy: Ploidy (only 1 or 2 are supported). :param numpy.ndarray prob_recombination: Recombination probability. :param numpy.ndarray prob_mutation: Mutation probability. @@ -74,14 +79,14 @@ def check_inputs( if ploidy == 2: if not np.all(np.isin(reference_panel, [0, 1, core.NONCOPY])): - err_msg = "Reference panel has illegal alleles. " - err_msg += "Only 0/1 encoding is supported in diploid mode." + err_msg = "Reference panel has not allowed in diploid mode. " + err_msg += "Only 0/1 biallelic encoding is supported." raise ValueError(err_msg) num_sites, num_ref_haps = reference_panel.shape # Check the queries. - if query.shape[0] != ploidy: + if query.shape[0] != 1: err_msg = "Query array has incorrect dimensions." raise ValueError(err_msg) @@ -94,9 +99,9 @@ def check_inputs( raise ValueError(err_msg) if ploidy == 2: - if not np.all(np.isin(query, [0, 1, core.MISSING])): - err_msg = "Query has illegal alleles. " - err_msg += "Only 0/1 encoding is supported in diploid mode." + if not np.all(np.isin(query, [0, 1, 2, core.MISSING])): + err_msg = "Query has states not allowed in diploid mode. " + err_msg += "Only 0/1/2 allele dosage encoding is supported." raise ValueError(err_msg) # Check the recombination probability. @@ -139,8 +144,30 @@ def check_inputs( err_msg = "Mutation probability is not a scalar or an array of expected length." raise ValueError(err_msg) + # Get the number of distinct alleles per site. + if ploidy == 1: + num_alleles = core.get_num_alleles(reference_panel, query) + else: + # TODO: This is a hack, because the ref. panel and query have different encodings. + # This needs to be overhauled when or before we deal with multiallelic sites. + # Also, this only works if we work with only biallelic sites. + query_unraveled = np.zeros((2, num_sites), dtype=np.int8) - np.inf + for i in range(num_sites): + if query[0, i] == 0: + query_unraveled[:, i] = np.array([0, 0]) + elif query[0, i] == 1: + query_unraveled[:, i] = np.array([0, 1]) + elif query[0, i] == 2: + query_unraveled[:, i] = np.array([1, 1]) + else: + query_unraveled[:, i] = np.array([core.MISSING, core.MISSING]) + num_alleles = core.get_num_alleles(reference_panel, query_unraveled) + + if not np.all(num_alleles > 1): + err_msg = "Some sites have less than two distinct alleles." + raise ValueError(err_msg) + # Calculate the emission probability matrix. - num_alleles = core.get_num_alleles(reference_panel, query) if ploidy == 1: emission_matrix = core.get_emission_matrix_haploid( mu=prob_mutation, @@ -168,12 +195,11 @@ def check_inputs( ref_panel_genotypes = core.convert_haplotypes_to_phased_genotypes( reference_panel ) - query_genotypes = core.convert_haplotypes_to_unphased_genotypes(query) return ( num_ref_haps, num_sites, - ref_panel_genotypes, - query_genotypes, + ref_panel_genotypes, # Only ref. panel is converted + query, emission_matrix, ) diff --git a/tests/lsbase.py b/tests/lsbase.py index 513e7ee..f54b75d 100644 --- a/tests/lsbase.py +++ b/tests/lsbase.py @@ -173,11 +173,13 @@ def get_examples_pars( # because we can now get back mutations that # result in the number of alleles being higher # than the number of alleles in the reference panel. - num_alleles = core.get_num_alleles(H, query) prob_mutation = mu if prob_mutation is None: # Note that n is the number of haplotypes, including ancestors. prob_mutation = np.zeros(m) + core.estimate_mutation_probability(n) + + num_alleles = core.get_num_alleles(H, query) + if ploidy == 1: e = core.get_emission_matrix_haploid( mu=prob_mutation, @@ -192,6 +194,9 @@ def get_examples_pars( num_alleles=num_alleles, scale_mutation_rate=scale_mutation_rate, ) + # In the diploid case, query is converted to unphased genotypes. + query = core.convert_haplotypes_to_unphased_genotypes(query) + yield n, m, H, query, e, r, mu # Prepare simple example datasets. @@ -204,11 +209,13 @@ def get_ts_simple_n10_no_recomb(self, seed=42): recombination_rate=0.0, random_seed=seed, ), + rate=0.3, model=msprime.BinaryMutationModel(), - rate=0.5, + discrete_genome=False, random_seed=seed, ) - assert ts.num_sites > 3 + assert ts.num_sites > 5 + assert ts.num_sites < 25 return ts def get_ts_simple(self, num_samples, seed=42): @@ -220,11 +227,13 @@ def get_ts_simple(self, num_samples, seed=42): recombination_rate=2.0, random_seed=seed, ), - rate=5.0, + rate=0.2, model=msprime.BinaryMutationModel(), + discrete_genome=False, random_seed=seed, ) assert ts.num_sites > 5 + assert ts.num_sites < 25 return ts def get_ts_simple_n8_high_recomb(self, seed=42): @@ -236,12 +245,14 @@ def get_ts_simple_n8_high_recomb(self, seed=42): recombination_rate=20.0, random_seed=seed, ), - rate=5.0, + rate=0.2, model=msprime.BinaryMutationModel(), + discrete_genome=False, random_seed=seed, ) assert ts.num_trees > 15 assert ts.num_sites > 5 + assert ts.num_sites < 25 return ts def get_ts_custom_pars(self, num_samples, seq_length, mean_r, mean_mu, seed=42): @@ -255,6 +266,7 @@ def get_ts_custom_pars(self, num_samples, seq_length, mean_r, mean_mu, seed=42): ), rate=mean_mu, model=msprime.BinaryMutationModel(), + discrete_genome=False, random_seed=seed, ) return ts @@ -302,8 +314,8 @@ def get_ts_multiallelic_n8(self, seed=42): rate=1e-4, random_seed=seed, ) - assert ts.num_sites > 5 assert ts.num_trees > 15 + assert ts.num_sites > 5 return ts def get_ts_multiallelic_n16(self, seed=42): diff --git a/tests/test_api_fb_diploid.py b/tests/test_api_fb_diploid.py index a092beb..c46864d 100644 --- a/tests/test_api_fb_diploid.py +++ b/tests/test_api_fb_diploid.py @@ -17,13 +17,12 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): include_extreme_rates=True, ): G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs) - s = core.convert_haplotypes_to_unphased_genotypes(query) F_vs, c_vs, ll_vs = fbd.forward_ls_dip_loop( n=n, m=m, G=G_vs, - s=s, + s=query, e=e_vs, r=r, norm=True, @@ -32,7 +31,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): n=n, m=m, G=G_vs, - s=s, + s=query, e=e_vs, c=c_vs, r=r, @@ -66,7 +65,7 @@ def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n10_no_recomb() self.verify(ts, scale_mutation_rate, include_ancestors) - @pytest.mark.parametrize("num_samples", [4, 8, 16]) + @pytest.mark.parametrize("num_samples", [8, 16]) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_simple(self, num_samples, scale_mutation_rate, include_ancestors): @@ -75,13 +74,13 @@ def test_ts_simple(self, num_samples, scale_mutation_rate, include_ancestors): @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [True, False]) - def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors): + def ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n8_high_recomb() self.verify(ts, scale_mutation_rate, include_ancestors) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [True, False]) - def test_ts_larger(self, scale_mutation_rate, include_ancestors): + def ts_larger(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_custom_pars( num_samples=30, seq_length=1e5, mean_r=1e-5, mean_mu=1e-5 ) diff --git a/tests/test_api_vit_diploid.py b/tests/test_api_vit_diploid.py index ea6a47c..59147e1 100644 --- a/tests/test_api_vit_diploid.py +++ b/tests/test_api_vit_diploid.py @@ -17,13 +17,12 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): include_extreme_rates=True, ): G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs) - s = core.convert_haplotypes_to_unphased_genotypes(query) V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem( n=n, m=m, G=G_vs, - s=s, + s=query, e=e_vs, r=r, ) @@ -47,7 +46,7 @@ def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n10_no_recomb() self.verify(ts, scale_mutation_rate, include_ancestors) - @pytest.mark.parametrize("num_samples", [4, 8, 16]) + @pytest.mark.parametrize("num_samples", [8, 16]) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_simple(self, num_samples, scale_mutation_rate, include_ancestors): diff --git a/tests/test_nontree_fb_diploid.py b/tests/test_nontree_fb_diploid.py index 9e13c84..a76285c 100644 --- a/tests/test_nontree_fb_diploid.py +++ b/tests/test_nontree_fb_diploid.py @@ -25,45 +25,48 @@ def verify( include_extreme_rates=include_extreme_rates, ): G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs) - s = core.convert_haplotypes_to_unphased_genotypes(query) - F_vs, c_vs, ll_vs = fbd.forwards_ls_dip(n, m, G_vs, s, e_vs, r, norm=True) - B_vs = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_vs, r) + F_vs, c_vs, ll_vs = fbd.forwards_ls_dip( + n, m, G_vs, query, e_vs, r, norm=True + ) + B_vs = fbd.backwards_ls_dip(n, m, G_vs, query, e_vs, c_vs, r) self.assertAllClose(np.sum(F_vs * B_vs, (1, 2)), np.ones(m)) F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop( - n, m, G_vs, s, e_vs, r, norm=True + n, m, G_vs, query, e_vs, r, norm=True ) - B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_tmp, r) + B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, query, e_vs, c_tmp, r) self.assertAllClose(np.sum(F_tmp * B_tmp, (1, 2)), np.ones(m)) self.assertAllClose(ll_vs, ll_tmp) if not normalise: F_tmp, c_tmp, ll_tmp = fbd.forwards_ls_dip( - n, m, G_vs, s, e_vs, r, norm=False + n, m, G_vs, query, e_vs, r, norm=False ) if ll_tmp != -np.inf: - B_tmp = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_tmp, r) + B_tmp = fbd.backwards_ls_dip(n, m, G_vs, query, e_vs, c_tmp, r) self.assertAllClose( np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) ) self.assertAllClose(ll_vs, ll_tmp) F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop( - n, m, G_vs, s, e_vs, r, norm=False + n, m, G_vs, query, e_vs, r, norm=False ) if ll_tmp != -np.inf: - B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_tmp, r) + B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, query, e_vs, c_tmp, r) self.assertAllClose( np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) ) self.assertAllClose(ll_vs, ll_tmp) F_tmp, ll_tmp = fbd.forward_ls_dip_starting_point( - n, m, G_vs, s, e_vs, r + n, m, G_vs, query, e_vs, r ) if ll_tmp != -np.inf: - B_tmp = fbd.backward_ls_dip_starting_point(n, m, G_vs, s, e_vs, r) + B_tmp = fbd.backward_ls_dip_starting_point( + n, m, G_vs, query, e_vs, r + ) self.assertAllClose( np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) ) @@ -87,7 +90,7 @@ def test_ts_simple_n10_no_recomb( include_extreme_rates=include_extreme_rates, ) - @pytest.mark.parametrize("num_samples", [4, 8, 16]) + @pytest.mark.parametrize("num_samples", [8, 16]) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [True, False]) @pytest.mark.parametrize("normalise", [True, False]) diff --git a/tests/test_nontree_vit_diploid.py b/tests/test_nontree_vit_diploid.py index ea49df8..205f3a8 100644 --- a/tests/test_nontree_vit_diploid.py +++ b/tests/test_nontree_vit_diploid.py @@ -18,12 +18,13 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): include_extreme_rates=True, ): G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs) - s = core.convert_haplotypes_to_unphased_genotypes(query) - V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem(n, m, G_vs, s, e_vs, r) + V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem( + n, m, G_vs, query, e_vs, r + ) path_vs = vd.backwards_viterbi_dip(m, V_vs, P_vs) phased_path_vs = vd.get_phased_path(n, path_vs) - path_ll_vs = vd.path_ll_dip(n, m, G_vs, phased_path_vs, s, e_vs, r) + path_ll_vs = vd.path_ll_dip(n, m, G_vs, phased_path_vs, query, e_vs, r) self.assertAllClose(ll_vs, path_ll_vs) ( @@ -34,7 +35,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): recombs_single, recombs_double, ll_tmp, - ) = vd.forwards_viterbi_dip_low_mem_no_pointer(n, m, G_vs, s, e_vs, r) + ) = vd.forwards_viterbi_dip_low_mem_no_pointer(n, m, G_vs, query, e_vs, r) path_tmp = vd.backwards_viterbi_dip_no_pointer( m, V_argmaxes_tmp, @@ -45,7 +46,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): V_tmp, ) phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) + path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, query, e_vs, r) self.assertAllClose(ll_tmp, path_ll_tmp) self.assertAllClose(ll_vs, ll_tmp) @@ -54,29 +55,35 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): if num_ref_haps <= MAX_NUM_REF_HAPS: # Run tests for the naive implementations. V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive( - n, m, G_vs, s, e_vs, r + n, m, G_vs, query, e_vs, r ) path_tmp = vd.backwards_viterbi_dip(m, V_tmp[m - 1, :, :], P_tmp) phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) + path_ll_tmp = vd.path_ll_dip( + n, m, G_vs, phased_path_tmp, query, e_vs, r + ) self.assertAllClose(ll_tmp, path_ll_tmp) self.assertAllClose(ll_vs, ll_tmp) V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_low_mem( - n, m, G_vs, s, e_vs, r + n, m, G_vs, query, e_vs, r ) path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) + path_ll_tmp = vd.path_ll_dip( + n, m, G_vs, phased_path_tmp, query, e_vs, r + ) self.assertAllClose(ll_tmp, path_ll_tmp) self.assertAllClose(ll_vs, ll_tmp) V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_vec( - n, m, G_vs, s, e_vs, r + n, m, G_vs, query, e_vs, r ) path_tmp = vd.backwards_viterbi_dip(m, V_tmp[m - 1, :, :], P_tmp) phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) + path_ll_tmp = vd.path_ll_dip( + n, m, G_vs, phased_path_tmp, query, e_vs, r + ) self.assertAllClose(ll_tmp, path_ll_tmp) self.assertAllClose(ll_vs, ll_tmp) @@ -86,7 +93,7 @@ def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n10_no_recomb() self.verify(ts, scale_mutation_rate, include_ancestors) - @pytest.mark.parametrize("num_samples", [4, 8, 16]) + @pytest.mark.parametrize("num_samples", [8, 16]) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_simple(self, num_samples, scale_mutation_rate, include_ancestors):