Skip to content

Commit

Permalink
Scale mutation rate for the diploid case and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jun 1, 2024
1 parent 8741f63 commit ab47840
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 60 deletions.
5 changes: 4 additions & 1 deletion lshmm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,10 @@ def set_emission_probabilities(
)
else:
emission_probs = core.get_emission_matrix_diploid(
mu=prob_mutation, num_sites=num_sites
mu=prob_mutation,
num_sites=num_sites,
num_alleles=num_alleles,
scale_mutation_rate=scale_mutation_rate,
)

return emission_probs
Expand Down
38 changes: 29 additions & 9 deletions lshmm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,15 +224,35 @@ def get_emission_matrix_haploid(mu, num_sites, num_alleles, scale_mutation_rate)
return e


def get_emission_matrix_diploid(mu, num_sites):
    """
    Return the emission probability matrix for the diploid case.

    The matrix has one row per site and 8 columns. Each assigned column is
    filled with the same expression of `mu` for every site; columns that are
    not assigned below stay at zero.

    :param mu: Mutation probability. It must support arithmetic with scalars
        and be assignable to a whole column (e.g. a float or a NumPy array
        with one entry per site).
    :param num_sites: Number of sites, i.e. number of rows in the matrix.
    :return: NumPy array of shape (num_sites, 8).
    """
    p_mut = mu
    p_no_mut = 1 - mu

    # Column index -> emission probability, kept in the original fill order.
    column_values = (
        (EQUAL_BOTH_HOM, p_no_mut ** 2),
        (UNEQUAL_BOTH_HOM, p_mut**2),
        (BOTH_HET, p_no_mut ** 2 + p_mut**2),
        (REF_HOM_OBS_HET, 2 * p_mut * p_no_mut),
        (REF_HET_OBS_HOM, p_mut * p_no_mut),
        # Missing data is equally compatible with every reference genotype.
        (MISSING_INDEX, 1),
    )

    emission_matrix = np.zeros((num_sites, 8))
    for column, value in column_values:
        emission_matrix[:, column] = value
    return emission_matrix
def get_emission_matrix_diploid(mu, num_sites, num_alleles, scale_mutation_rate):
    """
    Return the per-site emission probability matrix for the diploid case.

    For every site a mutation probability `p` and a no-mutation probability
    `q` are derived from `mu` and the number of distinct alleles `a`:

    - `a == 1` (invariant site): `p = 0`, `q = 1`.
    - `scale_mutation_rate` is true: `p = mu`, `q = 1 - (a - 1) * mu`.
    - otherwise: `p = mu / (a - 1)`, `q = 1 - mu`.

    These are then combined into the genotype-level emission probabilities
    stored in the columns indexed by the module-level constants.

    :param mu: Mutation probability, either a scalar applied to every site
        or a sequence with one entry per site.
    :param num_sites: Number of sites, i.e. number of rows in the matrix.
    :param num_alleles: Sequence holding the number of distinct alleles at
        each site.
    :param scale_mutation_rate: Selects how `p` and `q` are derived from
        `mu` (see above).
    :return: float64 NumPy array of shape (num_sites, 8). Columns that are
        not assigned below are left at -inf.
    :raises AssertionError: If `mu` (after broadcasting a scalar) and
        `num_alleles` differ in length.
    """
    # Broadcast a scalar *before* the length check: len() is undefined for
    # scalars, so checking first made the scalar case fail with a TypeError.
    if np.ndim(mu) == 0:
        mu = np.full(num_sites, mu, dtype=np.float64)
    assert len(mu) == len(
        num_alleles
    ), "Arrays of mutation probability and number of alleles are unequal in length."

    prob_mutation = np.empty(num_sites, dtype=np.float64)
    prob_no_mutation = np.empty(num_sites, dtype=np.float64)
    for i in range(num_sites):
        if num_alleles[i] == 1:
            # Set probabilities at invariant sites.
            prob_mutation[i] = 0
            prob_no_mutation[i] = 1
        elif scale_mutation_rate:
            prob_mutation[i] = mu[i]
            prob_no_mutation[i] = 1 - (num_alleles[i] - 1) * mu[i]
        else:
            prob_mutation[i] = mu[i] / (num_alleles[i] - 1)
            prob_no_mutation[i] = 1 - mu[i]

    # Unassigned columns keep the -inf fill value.
    emission_matrix = np.full((num_sites, 8), -np.inf, dtype=np.float64)
    emission_matrix[:, EQUAL_BOTH_HOM] = prob_no_mutation**2
    emission_matrix[:, UNEQUAL_BOTH_HOM] = prob_mutation**2
    emission_matrix[:, BOTH_HET] = prob_no_mutation**2 + prob_mutation**2
    emission_matrix[:, REF_HOM_OBS_HET] = 2 * prob_mutation * prob_no_mutation
    emission_matrix[:, REF_HET_OBS_HOM] = prob_mutation * prob_no_mutation
    # Missing data is equally compatible with every reference genotype.
    emission_matrix[:, MISSING_INDEX] = 1
    return emission_matrix


@jit.numba_njit
Expand Down
12 changes: 10 additions & 2 deletions tests/lsbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,19 @@ def get_examples_pars(
num_alleles = core.get_num_alleles(H, query)
if ploidy == 1:
e = core.get_emission_matrix_haploid(
mu, m, num_alleles, scale_mutation_rate
mu=mu,
num_sites=m,
num_alleles=num_alleles,
scale_mutation_rate=scale_mutation_rate,
)
yield n, m, H, query, e, r, mu
else:
e = core.get_emission_matrix_diploid(mu, m)
e = core.get_emission_matrix_diploid(
mu=mu,
num_sites=m,
num_alleles=num_alleles,
scale_mutation_rate=scale_mutation_rate,
)
yield n, m, H, query, e, r, mu

# Prepare simple example datasets.
Expand Down
103 changes: 79 additions & 24 deletions tests/test_API.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs)
s = core.convert_haplotypes_to_unphased_genotypes(query)
num_alleles = core.get_num_alleles(H_vs, query)

F_vs, c_vs, ll_vs = fbd.forward_ls_dip_loop(
n=n,
m=m,
Expand Down Expand Up @@ -148,39 +149,65 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
self.assertAllClose(B, B_vs)
self.assertAllClose(ll_vs, ll)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n10_no_recomb(self, include_ancestors):
def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n6(self, include_ancestors):
def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8(self, include_ancestors):
def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8_high_recomb(self, include_ancestors):
def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n16(self, include_ancestors):
def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_larger(self, include_ancestors):
def test_ts_larger(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_custom_pars(
ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5
)
self.verify(
ts,
scale_mutation_rate=True,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

Expand Down Expand Up @@ -264,6 +291,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs)
s = core.convert_haplotypes_to_unphased_genotypes(query)
num_alleles = core.get_num_alleles(H_vs, query)

V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem(
n=n,
m=m,
Expand All @@ -282,41 +310,68 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
prob_mutation=mu,
scale_mutation_rate=scale_mutation_rate,
)

self.assertAllClose(ll_vs, ll)
self.assertAllClose(phased_path_vs, path)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n10_no_recomb(self, include_ancestors):
def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n6(self, include_ancestors):
def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8(self, include_ancestors):
def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8_high_recomb(self, include_ancestors):
def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n16(self, include_ancestors):
def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_larger(self, include_ancestors):
def test_ts_larger(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_custom_pars(
ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5
)
self.verify(
ts,
scale_mutation_rate=True,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
Loading

0 comments on commit ab47840

Please sign in to comment.