Skip to content

Commit

Permalink
WIP diploid
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Apr 19, 2024
1 parent 37d1b68 commit e82e151
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 102 deletions.
15 changes: 13 additions & 2 deletions lshmm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,18 @@ def get_index_in_emission_prob_matrix_diploid(ref_allele, query_allele):
if query_allele == MISSING:
return MISSING_INDEX
else:
is_allele_match = ref_allele == query_allele
is_match = ref_allele == query_allele
is_ref_one = ref_allele == 1
is_query_one = query_allele == 1
return 4 * is_allele_match + 2 * is_ref_one + is_query_one
return 4 * is_match + 2 * is_ref_one + is_query_one


@jit.numba_njit
def get_index_in_emission_prob_matrix_diploid_G(ref_G, query_allele, n):
    """
    Return the emission probability matrix indices for a whole reference
    genotype matrix against a single query allele.

    If the query allele is MISSING, an (n, n) int64 matrix filled with
    MISSING_INDEX is returned. Otherwise each entry of the result is
    4 * (ref == query) + 2 * (ref == 1) + (query == 1), evaluated
    element-wise over `ref_G`.

    :param ref_G: Reference genotype values at one site
        (presumably an (n, n) array; confirm against callers).
    :param query_allele: Query value at the same site.
    :param n: Side length used for the matrix returned in the MISSING case.
    :return: Matrix of indices into the emission probability matrix.
    """
    # Guard clause: a missing query maps every entry to the same index.
    if query_allele == MISSING:
        return MISSING_INDEX * np.ones((n, n), dtype=np.int64)

    matches_query = ref_G == query_allele
    ref_equals_one = ref_G == 1
    query_equals_one = query_allele == 1
    return 4 * matches_query + 2 * ref_equals_one + query_equals_one
124 changes: 48 additions & 76 deletions lshmm/vit_diploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,11 @@ def forwards_viterbi_dip_naive(n, m, G, s, e, r):
V[0, j1, j2] = 1 / (n**2) * e[0, emission_index]

for l in range(1, m):
if s[0, l] == core.MISSING:
index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64)
else:
index = (
4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64)
+ 2 * (G[l, :, :] == 1).astype(np.int64)
+ np.int64(s[0, l] == 1)
)
emission_index = core.get_index_in_emission_prob_matrix_diploid_G(
ref_G=G[l, :, :],
query_allele=s[0, l],
n=n,
)

for j1 in range(n):
for j2 in range(n):
Expand All @@ -50,7 +47,7 @@ def forwards_viterbi_dip_naive(n, m, G, s, e, r):
v[k1, k2] *= r_n[l] * (1 - r[l]) + r_n[l] ** 2
else:
v[k1, k2] *= r_n[l] ** 2
V[l, j1, j2] = np.amax(v) * e[l, index[j1, j2]]
V[l, j1, j2] = np.amax(v) * e[l, emission_index[j1, j2]]
P[l, j1, j2] = np.argmax(v)
c[l] = np.amax(V[l, :, :])
V[l, :, :] *= 1 / c[l]
Expand Down Expand Up @@ -81,17 +78,14 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r):
# see if we can pinch some ideas.
# Diploid Viterbi, with smaller memory footprint.
for l in range(1, m):
if s[0, l] == core.MISSING:
index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64)
else:
index = (
4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64)
+ 2 * (G[l, :, :] == 1).astype(np.int64)
+ np.int64(s[0, l] == 1)
)
emission_index = core.get_index_in_emission_prob_matrix_diploid_G(
ref_G=G[l, :, :],
query_allele=s[0, l],
n=n,
)

for j1 in range(n):
for j2 in range(n):
# Get the vector to maximise over
v = np.zeros((n, n))
for k1 in range(n):
for k2 in range(n):
Expand All @@ -104,7 +98,7 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r):
v[k1, k2] *= r_n[l] * (1 - r[l]) + r_n[l] ** 2
else:
v[k1, k2] *= r_n[l] ** 2
V[j1, j2] = np.amax(v) * e[l, index[j1, j2]]
V[j1, j2] = np.amax(v) * e[l, emission_index[j1, j2]]
P[l, j1, j2] = np.argmax(v)
c[l] = np.amax(V)
V_prev = np.copy(V) / c[l]
Expand Down Expand Up @@ -133,14 +127,11 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r):

# Diploid Viterbi, with smaller memory footprint, rescaling, and using the structure of the HMM.
for l in range(1, m):
if s[0, l] == core.MISSING:
index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64)
else:
index = (
4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64)
+ 2 * (G[l, :, :] == 1).astype(np.int64)
+ np.int64(s[0, l] == 1)
)
emission_index = core.get_index_in_emission_prob_matrix_diploid_G(
ref_G=G[l, :, :],
query_allele=s[0, l],
n=n,
)

c[l] = np.amax(V_prev)
argmax = np.argmax(V_prev)
Expand Down Expand Up @@ -183,7 +174,7 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r):
V[j1, j2] = double_switch
P[l, j1, j2] = argmax

V[j1, j2] *= e[l, index[j1, j2]]
V[j1, j2] *= e[l, emission_index[j1, j2]]
j1_j2 += 1
V_prev = np.copy(V)

Expand Down Expand Up @@ -221,14 +212,11 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r):

# Diploid Viterbi, with smaller memory footprint, rescaling, and using the structure of the HMM.
for l in range(1, m):
if s[0, l] == core.MISSING:
index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64)
else:
index = (
4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64)
+ 2 * (G[l, :, :] == 1).astype(np.int64)
+ np.int64(s[0, l] == 1)
)
emission_index = core.get_index_in_emission_prob_matrix_diploid_G(
ref_G=G[l, :, :],
query_allele=s[0, l],
n=n,
)

c[l] = np.amax(V_prev)
argmax = np.argmax(V_prev)
Expand Down Expand Up @@ -264,7 +252,7 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r):
V[j1, j2] = double_switch
recombs_double[l] = np.append(recombs_double[l], values=j1_j2)

V[j1, j2] *= e[l, index[j1, j2]]
V[j1, j2] *= e[l, emission_index[j1, j2]]
j1_j2 += 1
V_prev = np.copy(V)

Expand Down Expand Up @@ -302,14 +290,11 @@ def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r):

# Jumped the gun - vectorising.
for l in range(1, m):
if s[0, l] == core.MISSING:
index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64)
else:
index = (
4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64)
+ 2 * (G[l, :, :] == 1).astype(np.int64)
+ np.int64(s[0, l] == 1)
)
emission_index = core.get_index_in_emission_prob_matrix_diploid_G(
ref_G=G[l, :, :],
query_allele=s[0, l],
n=n,
)

for j1 in range(n):
for j2 in range(n):
Expand All @@ -318,7 +303,7 @@ def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r):
v[j1, :] += r_n[l] * (1 - r[l])
v[:, j2] += r_n[l] * (1 - r[l])
v *= V[l - 1, :, :]
V[l, j1, j2] = np.amax(v) * e[l, index[j1, j2]]
V[l, j1, j2] = np.amax(v) * e[l, emission_index[j1, j2]]
P[l, j1, j2] = np.argmax(v)

c[l] = np.amax(V[l, :, :])
Expand All @@ -341,34 +326,28 @@ def forwards_viterbi_dip_naive_full_vec(n, m, G, s, e, r):
P = np.zeros((m, n, n), dtype=np.int64)
c = np.ones(m)

if s[0, 0] == core.MISSING:
index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64)
else:
index = (
4 * np.equal(G[0, :, :], s[0, 0]).astype(np.int64)
+ 2 * (G[0, :, :] == 1).astype(np.int64)
+ np.int64(s[0, 0] == 1)
)
V[0, :, :] = 1 / (n**2) * e[0, index]
emission_index = core.get_index_in_emission_prob_matrix_diploid_G(
ref_G=G[0, :, :],
query_allele=s[0, 0],
n=n,
)
V[0, :, :] = 1 / (n**2) * e[0, emission_index]
r_n = r / n

for l in range(1, m):
if s[0, l] == core.MISSING:
index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64)
else:
index = (
4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64)
+ 2 * (G[l, :, :] == 1).astype(np.int64)
+ np.int64(s[0, l] == 1)
)
emission_index = core.get_index_in_emission_prob_matrix_diploid_G(
ref_G=G[l, :, :],
query_allele=s[0, l],
n=n,
)
v = (
(r_n[l] ** 2)
+ (1 - r[l]) ** 2 * char_both
+ (r_n[l] * (1 - r[l])) * (char_col + char_row)
)
v *= V[l - 1, :, :]
P[l, :, :] = np.argmax(v.reshape(n, n, -1), 2) # Have to flatten to use argmax
V[l, :, :] = v.reshape(n, n, -1)[rows, cols, P[l, :, :]] * e[l, index]
V[l, :, :] = v.reshape(n, n, -1)[rows, cols, P[l, :, :]] * e[l, emission_index]
c[l] = np.amax(V[l, :, :])
V[l, :, :] *= 1 / c[l]

Expand Down Expand Up @@ -455,17 +434,10 @@ def path_ll_dip(n, m, G, phased_path, s, e, r):
r_n = r / n

for l in range(1, m):
if s[0, l] == core.MISSING:
index = core.MISSING_INDEX
else:
index = (
4
* np.int64(
np.equal(G[l, phased_path[0][l], phased_path[1][l]], s[0, l])
)
+ 2 * np.int64(G[l, phased_path[0][l], phased_path[1][l]] == 1)
+ np.int64(s[0, l] == 1)
)
emission_index = core.get_index_in_emission_prob_matrix_diploid(
ref_allele=G[l, phased_path[0][l], phased_path[1][l]],
query_allele=s[0, l],
)

current_phase = np.array([phased_path[0][l], phased_path[1][l]])
phase_diff = np.sum(~np.equal(current_phase, old_phase))
Expand All @@ -479,7 +451,7 @@ def path_ll_dip(n, m, G, phased_path, s, e, r):
else:
log_prob_path += np.log10(r_n[l] ** 2)

log_prob_path += np.log10(e[l, index])
log_prob_path += np.log10(e[l, emission_index])
old_phase = current_phase

return log_prob_path
6 changes: 1 addition & 5 deletions tests/test_API_multiallelic.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,7 @@ def test_simple_n_8(self):
population_size=10000,
random_seed=42,
)
ts = msprime.sim_mutations(
ts,
rate=1e-4,
random_seed=42
)
ts = msprime.sim_mutations(ts, rate=1e-4, random_seed=42)
assert ts.num_sites > 5
assert ts.num_trees > 15
self.verify(ts)
Expand Down
53 changes: 34 additions & 19 deletions tests/test_LS_haploid_diploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,12 @@ def example_haplotypes(self, ts):
return H, haplotypes

def haplotype_emission(self, mu, m):
# Define the emission probability matrix
e = np.zeros((m, 2))
e[:, 0] = mu # If they match
e[:, 1] = 1 - mu # If they don't match
e[:, 0] = mu
e[:, 1] = 1 - mu
return e

def genotype_emission(self, mu, m):
# Define the emission probability matrix
e = np.zeros((m, 8))
e[:, core.EQUAL_BOTH_HOM] = (1 - mu) ** 2
e[:, core.UNEQUAL_BOTH_HOM] = mu**2
Expand Down Expand Up @@ -92,10 +90,8 @@ def example_parameters_haplotypes_larger(
r = mean_r * np.ones(m) * ((np.random.rand(m) + 0.5) / 2)
r[0] = 0

# Error probability
mu = mean_mu * np.ones(m) * ((np.random.rand(m) + 0.5) / 2)

# Define the emission probability matrix
e = self.haplotype_emission(mu, m)

for s in haplotypes:
Expand Down Expand Up @@ -167,46 +163,65 @@ def example_parameters_genotypes_larger(
r = mean_r * np.ones(m) * ((np.random.rand(m) + 0.5) / 2)
r[0] = 0

# Error probability
mu = mean_mu * np.ones(m) * ((np.random.rand(m) + 0.5) / 2)

# Define the emission probability matrix
e = self.genotype_emission(mu, m)

for s in genotypes:
yield n, m, G, s, e, r

def assertAllClose(self, A, B):
    """
    Assert that every entry of `A` is numerically close to the matching
    entry of `B`, using rtol=1e-09 and atol=1e-08.

    Raises AssertionError when any pair of entries differs beyond tolerance.
    """
    # A stricter variant (rtol=1e-9, atol=0.0) was tried previously:
    # assert np.allclose(A, B, rtol=1e-9, atol=0.0)
    entries_agree = np.allclose(A, B, rtol=1e-09, atol=1e-08)
    assert entries_agree

# Define a bunch of very small tree-sequences for testing a collection of parameters on
def test_simple_n_10_no_recombination(self):
ts = msprime.simulate(
10, recombination_rate=0, mutation_rate=0.5, random_seed=42
10,
recombination_rate=0,
mutation_rate=0.5,
random_seed=42,
)
assert ts.num_sites > 3
self.verify(ts)

def test_simple_n_6(self):
ts = msprime.simulate(6, recombination_rate=2, mutation_rate=7, random_seed=42)
ts = msprime.simulate(
6,
recombination_rate=2,
mutation_rate=7,
random_seed=42,
)
assert ts.num_sites > 5
self.verify(ts)

def test_simple_n_8(self):
ts = msprime.simulate(8, recombination_rate=2, mutation_rate=5, random_seed=42)
ts = msprime.simulate(
8,
recombination_rate=2,
mutation_rate=5,
random_seed=42,
)
assert ts.num_sites > 5
self.verify(ts)

def test_simple_n_8_high_recombination(self):
ts = msprime.simulate(8, recombination_rate=20, mutation_rate=5, random_seed=42)
ts = msprime.simulate(
8,
recombination_rate=20,
mutation_rate=5,
random_seed=42,
)
assert ts.num_trees > 15
assert ts.num_sites > 5
self.verify(ts)

def test_simple_n_16(self):
ts = msprime.simulate(16, recombination_rate=2, mutation_rate=5, random_seed=42)
ts = msprime.simulate(
16,
recombination_rate=2,
mutation_rate=5,
random_seed=42,
)
assert ts.num_sites > 5
self.verify(ts)

Expand All @@ -233,7 +248,7 @@ class FBAlgorithmBase(LSBase):


class TestNonTreeMethodsHap(FBAlgorithmBase):
"""Test that we compute the sample likelihoods across all implementations."""
"""Test that the computed likelihoods are the same across all implementations."""

def verify(self, ts):
for n, m, H_vs, s, e_vs, r in self.example_parameters_haplotypes(ts):
Expand Down Expand Up @@ -269,7 +284,7 @@ def verify_larger(self, ts):


class TestNonTreeMethodsDip(FBAlgorithmBase):
"""Test that we compute the sample likelihoods across all implementations."""
"""Test that the computed likelihoods are the same across all implementations."""

def verify(self, ts):
for n, m, G_vs, s, e_vs, r in self.example_parameters_genotypes(ts):
Expand Down Expand Up @@ -356,7 +371,7 @@ class VitAlgorithmBase(LSBase):


class TestNonTreeViterbiHap(VitAlgorithmBase):
"""Test that we have the same log-likelihood across all implementations"""
"""Test that the computed log-likelihoods are the same across all implementations."""

def verify(self, ts):
for n, m, H_vs, s, e_vs, r in self.example_parameters_haplotypes(ts):
Expand Down Expand Up @@ -476,7 +491,7 @@ def verify_larger(self, ts):


class TestNonTreeViterbiDip(VitAlgorithmBase):
"""Test that we have the same log-likelihood across all implementations"""
"""Test that the computed log-likelihoods are the same across all implementations."""

def verify(self, ts):
for n, m, G_vs, s, e_vs, r in self.example_parameters_genotypes(ts):
Expand Down

0 comments on commit e82e151

Please sign in to comment.