Skip to content

Commit

Permalink
Modify haploid Viterbi to handle NONCOPY state in reference panel
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Mar 26, 2024
1 parent f94dd05 commit a5c3a7b
Show file tree
Hide file tree
Showing 2 changed files with 291 additions and 25 deletions.
70 changes: 45 additions & 25 deletions lshmm/vit_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from . import jit

MISSING = -1
NONCOPY = -2


@jit.numba_njit
Expand All @@ -13,10 +14,10 @@ def viterbi_naive_init(n, m, H, s, e, r):
P = np.zeros((m, n)).astype(np.int64)
r_n = r / n
for i in range(n):
V[0, i] = (
1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
)

em_prob = 0
if H[0, i] != NONCOPY:
em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
V[0, i] = 1 / n * em_prob
return V, P, r_n


Expand All @@ -29,9 +30,10 @@ def viterbi_init(n, m, H, s, e, r):
r_n = r / n

for i in range(n):
V_previous[i] = (
1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
)
em_prob = 0
if H[0, i] != NONCOPY:
em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
V_previous[i] = 1 / n * em_prob

return V, V_previous, P, r_n

Expand All @@ -47,10 +49,10 @@ def forwards_viterbi_hap_naive(n, m, H, s, e, r):
# Get the vector to maximise over
v = np.zeros(n)
for k in range(n):
v[k] = (
e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
* V[j - 1, k]
)
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
v[k] = em_prob * V[j - 1, k]
if k == i:
v[k] *= 1 - r[j] + r_n[j]
else:
Expand All @@ -74,7 +76,10 @@ def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r):
for i in range(n):
v = np.copy(v_tmp)
v[i] += V[j - 1, i] * (1 - r[j])
v *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
v *= em_prob
P[j, i] = np.argmax(v)
V[j, i] = v[P[j, i]]

Expand All @@ -94,10 +99,10 @@ def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r):
# Get the vector to maximise over
v = np.zeros(n)
for k in range(n):
v[k] = (
e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
* V_previous[k]
)
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
v[k] = (em_prob * V_previous[k])
if k == i:
v[k] *= 1 - r[j] + r_n[j]
else:
Expand Down Expand Up @@ -125,10 +130,10 @@ def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r):
# Get the vector to maximise over
v = np.zeros(n)
for k in range(n):
v[k] = (
e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
* V_previous[k]
)
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
v[k] = em_prob * V_previous[k]
if k == i:
v[k] *= 1 - r[j] + r_n[j]
else:
Expand Down Expand Up @@ -161,7 +166,10 @@ def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r):
if V[i] < r_n[j]:
V[i] = r_n[j]
P[j, i] = argmax
V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
V[i] *= em_prob
V_previous = np.copy(V)

ll = np.sum(np.log10(c)) + np.log10(np.max(V))
Expand All @@ -175,7 +183,10 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r):
# Initialise
V = np.zeros(n)
for i in range(n):
V[i] = 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
em_prob = 0
if H[0, i] != NONCOPY:
em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
V[i] = 1 / n * em_prob
P = np.zeros((m, n)).astype(np.int64)
r_n = r / n
c = np.ones(m)
Expand All @@ -190,7 +201,10 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r):
if V[i] < r_n[j]:
V[i] = r_n[j]
P[j, i] = argmax
V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
V[i] *= em_prob

ll = np.sum(np.log10(c)) + np.log10(np.max(V))

Expand All @@ -203,7 +217,10 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r):
# Initialise
V = np.zeros(n)
for i in range(n):
V[i] = 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
em_prob = 0
if H[0, i] != NONCOPY:
em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
V[i] = 1 / n * em_prob
r_n = r / n
c = np.ones(m)
recombs = [
Expand All @@ -224,7 +241,10 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r):
recombs[j] = np.append(
recombs[j], i
) # We add template i as a potential template to recombine to at site j.
V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
V[i] *= em_prob

V_argmaxes[m - 1] = np.argmax(V)
ll = np.sum(np.log10(c)) + np.log10(np.max(V))
Expand Down
Loading

0 comments on commit a5c3a7b

Please sign in to comment.