Skip to content

Commit

Permalink
Pre-release checks
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jun 18, 2024
1 parent 014a249 commit b1e215b
Show file tree
Hide file tree
Showing 12 changed files with 113 additions and 81 deletions.
18 changes: 9 additions & 9 deletions lshmm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def check_inputs(
reference_panel,
query,
prob_recombination,
prob_mutation=None,
scale_mutation_rate=None,
prob_mutation,
scale_mutation_rate,
):
"""
Check that the input data and parameters are valid, and return basic info
Expand All @@ -50,8 +50,8 @@ def check_inputs(
:param numpy.ndarray reference_panel: An array of size (m, n) or (m, n, n).
:param numpy.ndarray query: An array of size (k, m).
:param numpy.ndarray prob_recombination: Recombination probability.
:param numpy.ndarray prob_mutation: Mutation probability. If None (default), set as per Li & Stephens (2003).
:param bool scale_mutation_rate: Scale mutation rate if True (default).
:param numpy.ndarray prob_mutation: Mutation probability.
:param bool scale_mutation_rate: Scale mutation rate.
:return: Number of reference haplotypes, number of sites, ploidy
:rtype: tuple
"""
Expand All @@ -60,7 +60,7 @@ def check_inputs(

# Check the reference panel.
if not len(reference_panel.shape) in (2, 3):
err_msg = "Reference panel array must have 2 or 3 dimensions."
err_msg = "Reference panel array has incorrect dimensions."
raise ValueError(err_msg)

if len(reference_panel.shape) == 2:
Expand Down Expand Up @@ -129,7 +129,7 @@ def set_emission_probabilities(
scale_mutation_rate,
):
if isinstance(prob_mutation, float):
prob_mutation = prob_mutation * np.ones(num_sites)
prob_mutation = np.zeros(num_sites) + prob_mutation

if ploidy == 1:
emission_probs = core.get_emission_matrix_haploid(
Expand Down Expand Up @@ -159,7 +159,7 @@ def forwards(
scale_mutation_rate=None,
normalise=None,
):
"""Run the forwards algorithm on haplotype or unphased genotype data."""
"""Run the forwards algorithm on haploid or diploid genotype data."""
if scale_mutation_rate is None:
scale_mutation_rate = True

Expand Down Expand Up @@ -217,7 +217,7 @@ def backwards(
prob_mutation=None,
scale_mutation_rate=None,
):
"""Run the backwards algorithm on haplotype or unphased genotype data."""
"""Run the backwards algorithm on haploid or diploid genotype data."""
if scale_mutation_rate is None:
scale_mutation_rate = True

Expand Down Expand Up @@ -267,7 +267,7 @@ def viterbi(
prob_mutation=None,
scale_mutation_rate=None,
):
"""Run the Viterbi algorithm on haplotype or unphased genotype data."""
"""Run the Viterbi algorithm on haploid or diploid genotype data."""
if scale_mutation_rate is None:
scale_mutation_rate = True

Expand Down
11 changes: 9 additions & 2 deletions lshmm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def check_genotype_matrix(genotype_matrix, num_sample_haps):
m = number of sites.
n = number of haplotypes (sample and ancestor) in the reference panel.
The maximum value is equal to (2n - 1), where n is the number of sample haplotypes
The maximum value is equal to (2*k - 1), where k is the number of sample haplotypes
in the genotype matrix, when a marginal tree is fully binary.
:param numpy.ndarray genotype_matrix: An array containing the reference haplotypes.
Expand Down Expand Up @@ -422,7 +422,14 @@ def get_index_in_emission_matrix_diploid(ref_genotype, query_genotype):

# Miscellaneous functions.
def estimate_mutation_probability(num_haps):
"""Return the mutation probability as defined by A2 and A3 in Li & Stephens (2003)."""
"""
Return an estimate of mutation probability based on the number of haplotypes
as defined by the equations A2 and A3 in Li & Stephens (2003).
:param int num_haps: Number of haplotypes.
:return: Estimate of mutation probability.
:rtype: float
"""
if num_haps < 3:
err_msg = "Number of haplotypes must be at least 3."
raise ValueError(err_msg)
Expand Down
5 changes: 1 addition & 4 deletions lshmm/fb_diploid.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
"""
Various implementations of the Li & Stephens forwards-backwards algorithm on diploid genotype data,
where the data is structured as variants x samples x samples.
"""
"""Implementations of the Li & Stephens forwards-backwards algorithm on diploid genotype data."""

import numpy as np

Expand Down
9 changes: 3 additions & 6 deletions lshmm/fb_haploid.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
"""
Various implementations of the Li & Stephens forwards-backwards algorithm on haploid genotype data,
where the data is structured as variants x samples.
"""
"""Implementations of the Li & Stephens forwards-backwards algorithm on haploid genotype data."""

import numpy as np

Expand All @@ -12,7 +9,7 @@
@jit.numba_njit
def forwards_ls_hap(n, m, H, s, e, r, norm=True):
"""
A matrix-based implementation using Numpy vectorisation.
A matrix-based implementation using Numpy.
This is exposed via the API.
"""
Expand Down Expand Up @@ -84,7 +81,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
@jit.numba_njit
def backwards_ls_hap(n, m, H, s, e, c, r):
"""
A matrix-based implementation using Numpy vectorisation.
A matrix-based implementation using Numpy.
This is exposed via the API.
"""
Expand Down
7 changes: 2 additions & 5 deletions lshmm/vit_diploid.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
"""
Various implementations of the Li & Stephens Viterbi algorithm on diploid genotype data,
where the data is structured as variants x samples x samples.
"""
"""Implementations of the Li & Stephens Viterbi algorithm on diploid genotype data."""

import numpy as np

Expand Down Expand Up @@ -461,7 +458,7 @@ def get_phased_path(n, path):
@jit.numba_njit
def path_ll_dip(n, m, G, phased_path, s, e, r):
"""
Evaluate log-likelihood path through a reference panel which results in sequence.
Evaluate the log-likelihood of a path through a reference panel resulting in a query.
This is exposed via the API.
"""
Expand Down
2 changes: 1 addition & 1 deletion lshmm/vit_haploid.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Implementations of the Li & Stephens Viterbi algorithm on haploid data."""
"""Implementations of the Li & Stephens Viterbi algorithm on haploid genotype data."""

import numpy as np

Expand Down
77 changes: 37 additions & 40 deletions tests/lsbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,14 @@ def get_examples_pars(
include_extreme_rates,
seed=42,
):
"""Returns an iterator over combinations of examples and parameters."""
"""Return an iterator over combinations of examples and parameters."""
assert ploidy in [1, 2]
assert scale_mutation_rate in [True, False]
assert include_ancestors in [True, False]
assert include_extreme_rates in [True, False]

np.random.seed(seed)

if ploidy == 1:
H, queries = self.get_examples_haploid(ts, include_ancestors)
else:
Expand All @@ -156,7 +157,7 @@ def get_examples_pars(
for i in range(len(r_s)):
r_s[i][0] = 0

mus = [
mu_s = [
np.zeros(m) + 0.01, # Equal recombination and mutation
np.random.rand(m) * 0.2, # Random
1e-5 * (np.random.rand(m) + 0.5) / 2,
Expand All @@ -166,17 +167,18 @@ def get_examples_pars(
if include_extreme_rates:
r_s.append(np.zeros(m) + 0.2)
r_s.append(np.zeros(m) + 1e-6)
mus.append(np.zeros(m) + 0.2)
mus.append(np.zeros(m) + 1e-6)
mu_s.append(np.zeros(m) + 0.2)
mu_s.append(np.zeros(m) + 1e-6)

for query, r, mu in itertools.product(queries, r_s, mus):
for query, r, mu in itertools.product(queries, r_s, mu_s):
# Must be calculated from the genotype matrix,
# because we can now get back mutations that
# result in the number of alleles being higher
# than the number of alleles in the reference panel.
num_alleles = core.get_num_alleles(H, query)
prob_mutation = mu
if prob_mutation is None:
# Note that n is the number of haplotypes, including ancestors.
prob_mutation = np.zeros(m) + core.estimate_mutation_probability(n)
if ploidy == 1:
e = core.get_emission_matrix_haploid(
Expand All @@ -185,15 +187,14 @@ def get_examples_pars(
num_alleles=num_alleles,
scale_mutation_rate=scale_mutation_rate,
)
yield n, m, H, query, e, r, mu
else:
e = core.get_emission_matrix_diploid(
mu=prob_mutation,
num_sites=m,
num_alleles=num_alleles,
scale_mutation_rate=scale_mutation_rate,
)
yield n, m, H, query, e, r, mu
yield n, m, H, query, e, r, mu

# Prepare simple example datasets.
def get_ts_simple_n10_no_recomb(self, seed=42):
Expand Down Expand Up @@ -249,7 +250,7 @@ def get_ts_simple_n16(self, seed=42):

def get_ts_custom_pars(self, ref_panel_size, length, mean_r, mean_mu, seed=42):
ts = msprime.simulate(
ref_panel_size + 1,
ref_panel_size,
length=length,
recombination_rate=mean_r,
mutation_rate=mean_mu,
Expand All @@ -259,47 +260,44 @@ def get_ts_custom_pars(self, ref_panel_size, length, mean_r, mean_mu, seed=42):

# Prepare example datasets with multiallelic sites.
def get_ts_multiallelic_n10_no_recomb(self, seed=42):
ts = msprime.sim_ancestry(
samples=10,
recombination_rate=0,
sequence_length=10,
population_size=1e4,
random_seed=seed,
)
ts = msprime.sim_mutations(
ts,
msprime.sim_ancestry(
samples=10,
recombination_rate=0,
sequence_length=10,
population_size=1e4,
random_seed=seed,
),
rate=1e-5,
random_seed=seed,
)
assert ts.num_sites > 3
return ts

def get_ts_multiallelic_n6(self, seed=42):
ts = msprime.sim_ancestry(
samples=6,
recombination_rate=1e-4,
sequence_length=40,
population_size=1e4,
random_seed=seed,
)
ts = msprime.sim_mutations(
ts,
msprime.sim_ancestry(
samples=6,
recombination_rate=1e-4,
sequence_length=40,
population_size=1e4,
random_seed=seed,
),
rate=1e-3,
random_seed=seed,
)
assert ts.num_sites > 5
return ts

def get_ts_multiallelic_n8(self, seed=42):
ts = msprime.sim_ancestry(
samples=8,
recombination_rate=1e-4,
sequence_length=20,
population_size=1e4,
random_seed=seed,
)
ts = msprime.sim_mutations(
ts,
msprime.sim_ancestry(
samples=8,
recombination_rate=1e-4,
sequence_length=20,
population_size=1e4,
random_seed=seed,
),
rate=1e-4,
random_seed=seed,
)
Expand All @@ -308,15 +306,14 @@ def get_ts_multiallelic_n8(self, seed=42):
return ts

def get_ts_multiallelic_n16(self, seed=42):
ts = msprime.sim_ancestry(
samples=16,
recombination_rate=1e-2,
sequence_length=20,
population_size=1e4,
random_seed=seed,
)
ts = msprime.sim_mutations(
ts,
msprime.sim_ancestry(
samples=16,
recombination_rate=1e-2,
sequence_length=20,
population_size=1e4,
random_seed=seed,
),
rate=1e-4,
random_seed=seed,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_api_fb_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
prob_mutation=mu,
scale_mutation_rate=scale_mutation_rate,
)
self.assertAllClose(F, F_vs)
self.assertAllClose(B, B_vs)
self.assertAllClose(F_vs, F)
self.assertAllClose(B_vs, B)
self.assertAllClose(ll_vs, ll)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
Expand Down
24 changes: 20 additions & 4 deletions tests/test_api_fb_haploid_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,38 @@ def test_ts_multiallelic_n10_no_recomb(
self, scale_mutation_rate, include_ancestors
):
ts = self.get_ts_multiallelic_n10_no_recomb()
self.verify(ts, scale_mutation_rate, include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_multiallelic_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_multiallelic_n6()
self.verify(ts, scale_mutation_rate, include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_multiallelic_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_multiallelic_n8()
self.verify(ts, scale_mutation_rate, include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_multiallelic_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_multiallelic_n16()
self.verify(ts, scale_mutation_rate, include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
2 changes: 1 addition & 1 deletion tests/test_api_vit_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_larger(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_custom_pars(
ref_panel_size=46, length=1e5, mean_r=1e-5, mean_mu=1e-5
ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5
)
self.verify(
ts,
Expand Down
Loading

0 comments on commit b1e215b

Please sign in to comment.