diff --git a/lshmm/api.py b/lshmm/api.py index b5c0e10..7b32f38 100644 --- a/lshmm/api.py +++ b/lshmm/api.py @@ -30,8 +30,8 @@ def check_inputs( reference_panel, query, prob_recombination, - prob_mutation=None, - scale_mutation_rate=None, + prob_mutation, + scale_mutation_rate, ): """ Check that the input data and parameters are valid, and return basic info @@ -50,8 +50,8 @@ def check_inputs( :param numpy.ndarray reference_panel: An array of size (m, n) or (m, n, n). :param numpy.ndarray query: An array of size (k, m). :param numpy.ndarray prob_recombination: Recombination probability. - :param numpy.ndarray prob_mutation: Mutation probability. If None (default), set as per Li & Stephens (2003). - :param bool scale_mutation_rate: Scale mutation rate if True (default). + :param numpy.ndarray prob_mutation: Mutation probability. + :param bool scale_mutation_rate: Scale mutation rate. :return: Number of reference haplotypes, number of sites, ploidy :rtype: tuple """ @@ -60,7 +60,7 @@ def check_inputs( # Check the reference panel. if not len(reference_panel.shape) in (2, 3): - err_msg = "Reference panel array must have 2 or 3 dimensions." + err_msg = "Reference panel array has incorrect dimensions." raise ValueError(err_msg) if len(reference_panel.shape) == 2: @@ -129,7 +129,7 @@ def set_emission_probabilities( scale_mutation_rate, ): if isinstance(prob_mutation, float): - prob_mutation = prob_mutation * np.ones(num_sites) + prob_mutation = np.zeros(num_sites) + prob_mutation if ploidy == 1: emission_probs = core.get_emission_matrix_haploid( @@ -159,7 +159,7 @@ def forwards( scale_mutation_rate=None, normalise=None, ): - """Run the forwards algorithm on haplotype or unphased genotype data.""" + """Run the forwards algorithm on haploid or diploid genotype data.""" if scale_mutation_rate is None: scale_mutation_rate = True @@ -217,7 +217,7 @@ def backwards( prob_mutation=None, scale_mutation_rate=None, ): - """Run the backwards algorithm on haplotype or unphased genotype data.""" + """Run the backwards algorithm on haploid or diploid genotype data.""" if scale_mutation_rate is None: scale_mutation_rate = True @@ -267,7 +267,7 @@ def viterbi( prob_mutation=None, scale_mutation_rate=None, ): - """Run the Viterbi algorithm on haplotype or unphased genotype data.""" + """Run the Viterbi algorithm on haploid or diploid genotype data.""" if scale_mutation_rate is None: scale_mutation_rate = True diff --git a/lshmm/core.py b/lshmm/core.py index f6acf75..b8bd86d 100644 --- a/lshmm/core.py +++ b/lshmm/core.py @@ -135,7 +135,7 @@ def check_genotype_matrix(genotype_matrix, num_sample_haps): m = number of sites. n = number of haplotypes (sample and ancestor) in the reference panel. - The maximum value is equal to (2n - 1), where n is the number of sample haplotypes + The maximum value is equal to (2*k - 1), where k is the number of sample haplotypes in the genotype matrix, when a marginal tree is fully binary. :param numpy.ndarray genotype_matrix: An array containing the reference haplotypes. @@ -422,7 +422,14 @@ def get_index_in_emission_matrix_diploid(ref_genotype, query_genotype): # Miscellaneous functions. def estimate_mutation_probability(num_haps): - """Return the mutation probability as defined by A2 and A3 in Li & Stephens (2003).""" + """ + Return an estimate of mutation probability based on the number of haplotypes + as defined by the equations A2 and A3 in Li & Stephens (2003). + + :param int num_haps: Number of haplotypes. + :return: Estimate of mutation probability. + :rtype: float + """ if num_haps < 3: err_msg = "Number of haplotypes must be at least 3." raise ValueError(err_msg) diff --git a/lshmm/fb_diploid.py b/lshmm/fb_diploid.py index 406e1ac..a7e33ea 100644 --- a/lshmm/fb_diploid.py +++ b/lshmm/fb_diploid.py @@ -1,7 +1,4 @@ -""" -Various implementations of the Li & Stephens forwards-backwards algorithm on diploid genotype data, -where the data is structured as variants x samples x samples. -""" +"""Implementations of the Li & Stephens forwards-backwards algorithm on diploid genotype data.""" import numpy as np diff --git a/lshmm/fb_haploid.py b/lshmm/fb_haploid.py index 3b2a9d7..7ae8460 100644 --- a/lshmm/fb_haploid.py +++ b/lshmm/fb_haploid.py @@ -1,7 +1,4 @@ -""" -Various implementations of the Li & Stephens forwards-backwards algorithm on haploid genotype data, -where the data is structured as variants x samples. -""" +"""Implementations of the Li & Stephens forwards-backwards algorithm on haploid genotype data.""" import numpy as np @@ -12,7 +9,7 @@ @jit.numba_njit def forwards_ls_hap(n, m, H, s, e, r, norm=True): """ - A matrix-based implementation using Numpy vectorisation. + A matrix-based implementation using Numpy. This is exposed via the API. """ @@ -84,7 +81,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): @jit.numba_njit def backwards_ls_hap(n, m, H, s, e, c, r): """ - A matrix-based implementation using Numpy vectorisation. + A matrix-based implementation using Numpy. This is exposed via the API. """ diff --git a/lshmm/vit_diploid.py b/lshmm/vit_diploid.py index 7c3a8a1..48e1d01 100644 --- a/lshmm/vit_diploid.py +++ b/lshmm/vit_diploid.py @@ -1,7 +1,4 @@ -""" -Various implementations of the Li & Stephens Viterbi algorithm on diploid genotype data, -where the data is structured as variants x samples x samples. -""" +"""Implementations of the Li & Stephens Viterbi algorithm on diploid genotype data.""" import numpy as np @@ -461,7 +458,7 @@ def get_phased_path(n, path): @jit.numba_njit def path_ll_dip(n, m, G, phased_path, s, e, r): """ - Evaluate log-likelihood path through a reference panel which results in sequence. + Evaluate the log-likelihood of a path through a reference panel resulting in a query. This is exposed via the API. """ diff --git a/lshmm/vit_haploid.py b/lshmm/vit_haploid.py index 8c83314..87dbb97 100644 --- a/lshmm/vit_haploid.py +++ b/lshmm/vit_haploid.py @@ -1,4 +1,4 @@ -"""Implementations of the Li & Stephens Viterbi algorithm on haploid data.""" +"""Implementations of the Li & Stephens Viterbi algorithm on haploid genotype data.""" import numpy as np diff --git a/tests/lsbase.py b/tests/lsbase.py index e9acef1..d13c0b0 100644 --- a/tests/lsbase.py +++ b/tests/lsbase.py @@ -129,13 +129,14 @@ def get_examples_pars( include_extreme_rates, seed=42, ): - """Returns an iterator over combinations of examples and parameters.""" + """Return an iterator over combinations of examples and parameters.""" assert ploidy in [1, 2] assert scale_mutation_rate in [True, False] assert include_ancestors in [True, False] assert include_extreme_rates in [True, False] np.random.seed(seed) + if ploidy == 1: H, queries = self.get_examples_haploid(ts, include_ancestors) else: @@ -156,7 +157,7 @@ def get_examples_pars( for i in range(len(r_s)): r_s[i][0] = 0 - mus = [ + mu_s = [ np.zeros(m) + 0.01, # Equal recombination and mutation np.random.rand(m) * 0.2, # Random 1e-5 * (np.random.rand(m) + 0.5) / 2, @@ -166,10 +167,10 @@ def get_examples_pars( if include_extreme_rates: r_s.append(np.zeros(m) + 0.2) r_s.append(np.zeros(m) + 1e-6) - mus.append(np.zeros(m) + 0.2) - mus.append(np.zeros(m) + 1e-6) + mu_s.append(np.zeros(m) + 0.2) + mu_s.append(np.zeros(m) + 1e-6) - for query, r, mu in itertools.product(queries, r_s, mus): + for query, r, mu in itertools.product(queries, r_s, mu_s): # Must be calculated from the genotype matrix, # because we can now get back mutations that # result in the number of alleles being higher @@ -177,6 +178,7 @@ def get_examples_pars( num_alleles = core.get_num_alleles(H, query) prob_mutation = mu if prob_mutation is None: + # Note that n is the number of haplotypes, including ancestors. prob_mutation = np.zeros(m) + core.estimate_mutation_probability(n) if ploidy == 1: e = core.get_emission_matrix_haploid( @@ -185,7 +187,6 @@ def get_examples_pars( num_alleles=num_alleles, scale_mutation_rate=scale_mutation_rate, ) - yield n, m, H, query, e, r, mu else: e = core.get_emission_matrix_diploid( mu=prob_mutation, @@ -193,7 +194,7 @@ def get_examples_pars( num_alleles=num_alleles, scale_mutation_rate=scale_mutation_rate, ) - yield n, m, H, query, e, r, mu + yield n, m, H, query, e, r, mu # Prepare simple example datasets. def get_ts_simple_n10_no_recomb(self, seed=42): @@ -249,7 +250,7 @@ def get_ts_simple_n16(self, seed=42): def get_ts_custom_pars(self, ref_panel_size, length, mean_r, mean_mu, seed=42): ts = msprime.simulate( - ref_panel_size + 1, + ref_panel_size, length=length, recombination_rate=mean_r, mutation_rate=mean_mu, @@ -259,15 +260,14 @@ def get_ts_custom_pars(self, ref_panel_size, length, mean_r, mean_mu, seed=42): # Prepare example datasets with multiallelic sites. def get_ts_multiallelic_n10_no_recomb(self, seed=42): - ts = msprime.sim_ancestry( - samples=10, - recombination_rate=0, - sequence_length=10, - population_size=1e4, - random_seed=seed, - ) ts = msprime.sim_mutations( - ts, + msprime.sim_ancestry( + samples=10, + recombination_rate=0, + sequence_length=10, + population_size=1e4, + random_seed=seed, + ), rate=1e-5, random_seed=seed, ) @@ -275,15 +275,14 @@ def get_ts_multiallelic_n10_no_recomb(self, seed=42): return ts def get_ts_multiallelic_n6(self, seed=42): - ts = msprime.sim_ancestry( - samples=6, - recombination_rate=1e-4, - sequence_length=40, - population_size=1e4, - random_seed=seed, - ) ts = msprime.sim_mutations( - ts, + msprime.sim_ancestry( + samples=6, + recombination_rate=1e-4, + sequence_length=40, + population_size=1e4, + random_seed=seed, + ), rate=1e-3, random_seed=seed, ) @@ -291,15 +290,14 @@ def get_ts_multiallelic_n6(self, seed=42): return ts def get_ts_multiallelic_n8(self, seed=42): - ts = msprime.sim_ancestry( - samples=8, - recombination_rate=1e-4, - sequence_length=20, - population_size=1e4, - random_seed=seed, - ) ts = msprime.sim_mutations( - ts, + msprime.sim_ancestry( + samples=8, + recombination_rate=1e-4, + sequence_length=20, + population_size=1e4, + random_seed=seed, + ), rate=1e-4, random_seed=seed, ) @@ -308,15 +306,14 @@ def get_ts_multiallelic_n8(self, seed=42): return ts def get_ts_multiallelic_n16(self, seed=42): - ts = msprime.sim_ancestry( - samples=16, - recombination_rate=1e-2, - sequence_length=20, - population_size=1e4, - random_seed=seed, - ) ts = msprime.sim_mutations( - ts, + msprime.sim_ancestry( + samples=16, + recombination_rate=1e-2, + sequence_length=20, + population_size=1e4, + random_seed=seed, + ), rate=1e-4, random_seed=seed, ) diff --git a/tests/test_api_fb_haploid.py b/tests/test_api_fb_haploid.py index 536e896..2a24637 100644 --- a/tests/test_api_fb_haploid.py +++ b/tests/test_api_fb_haploid.py @@ -51,8 +51,8 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): prob_mutation=mu, scale_mutation_rate=scale_mutation_rate, ) - self.assertAllClose(F, F_vs) - self.assertAllClose(B, B_vs) + self.assertAllClose(F_vs, F) + self.assertAllClose(B_vs, B) self.assertAllClose(ll_vs, ll) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) diff --git a/tests/test_api_fb_haploid_multi.py b/tests/test_api_fb_haploid_multi.py index 6eb416c..a50b299 100644 --- a/tests/test_api_fb_haploid_multi.py +++ b/tests/test_api_fb_haploid_multi.py @@ -61,22 +61,38 @@ def test_ts_multiallelic_n10_no_recomb( self, scale_mutation_rate, include_ancestors ): ts = self.get_ts_multiallelic_n10_no_recomb() - self.verify(ts, scale_mutation_rate, include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_multiallelic_n6(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_multiallelic_n6() - self.verify(ts, scale_mutation_rate, include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_multiallelic_n8(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_multiallelic_n8() - self.verify(ts, scale_mutation_rate, include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_multiallelic_n16(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_multiallelic_n16() - self.verify(ts, scale_mutation_rate, include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) diff --git a/tests/test_api_vit_haploid.py b/tests/test_api_vit_haploid.py index dd78bfc..aa5c9fe 100644 --- a/tests/test_api_vit_haploid.py +++ b/tests/test_api_vit_haploid.py @@ -90,7 +90,7 @@ def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors): @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_larger(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_custom_pars( - ref_panel_size=46, length=1e5, mean_r=1e-5, mean_mu=1e-5 + ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5 ) self.verify( ts, diff --git a/tests/test_api_vit_haploid_multi.py b/tests/test_api_vit_haploid_multi.py index 15bbfde..037dc0d 100644 --- a/tests/test_api_vit_haploid_multi.py +++ b/tests/test_api_vit_haploid_multi.py @@ -17,9 +17,14 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): ): num_alleles = core.get_num_alleles(H_vs, s) V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling( - n, m, H_vs, s, e_vs, r + n=n, + m=m, + H=H_vs, + s=s, + e=e_vs, + r=r, ) - path_vs = vh.backwards_viterbi_hap(m, V_vs, P_vs) + path_vs = vh.backwards_viterbi_hap(m=m, V_last=V_vs, P=P_vs) path_ll_hap = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r) path, ll = ls.viterbi( reference_panel=H_vs, @@ -39,22 +44,38 @@ def test_ts_multiallelic_n10_no_recomb( self, scale_mutation_rate, include_ancestors ): ts = self.get_ts_multiallelic_n10_no_recomb() - self.verify(ts, scale_mutation_rate, include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_multiallelic_n6(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_multiallelic_n6() - self.verify(ts, scale_mutation_rate, include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_multiallelic_n8(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_multiallelic_n8() - self.verify(ts, scale_mutation_rate, include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [True, False]) def test_ts_multiallelic_n16(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_multiallelic_n16() - self.verify(ts, scale_mutation_rate, include_ancestors) + self.verify( + ts, + scale_mutation_rate=scale_mutation_rate, + include_ancestors=include_ancestors, + ) diff --git a/tests/test_nontree_fb_haploid.py b/tests/test_nontree_fb_haploid.py index 58c7623..dbd8eb0 100644 --- a/tests/test_nontree_fb_haploid.py +++ b/tests/test_nontree_fb_haploid.py @@ -81,10 +81,7 @@ def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors): @pytest.mark.parametrize("include_ancestors", [True, False]) def test_larger(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_custom_pars( - ref_panel_size=45, - length=1e5, - mean_r=1e-5, - mean_mu=1e-5, + ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5, ) self.verify( ts, diff --git a/tests/test_nontree_vit_haploid.py b/tests/test_nontree_vit_haploid.py index 837f59e..d53ca15 100644 --- a/tests/test_nontree_vit_haploid.py +++ b/tests/test_nontree_vit_haploid.py @@ -81,7 +81,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [True, False]) - def test_ts_simple_n10_no_recombn(self, scale_mutation_rate, include_ancestors): + def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors): ts = self.get_ts_simple_n10_no_recomb() self.verify( ts,