
Commit

WIP diploid
szhan committed Apr 19, 2024
1 parent 9dac3c0 commit fd011ca
Showing 2 changed files with 51 additions and 49 deletions.
50 changes: 25 additions & 25 deletions lshmm/vit_diploid.py
@@ -66,7 +66,7 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r):
"""A naive implementation with reduced memory."""
# Initialise
V = np.zeros((n, n))
V_previous = np.zeros((n, n))
V_prev = np.zeros((n, n))
P = np.zeros((m, n, n), dtype=np.int64)
c = np.ones(m)
r_n = r / n
@@ -77,7 +77,7 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r):
ref_allele=G[0, j1, j2],
query_allele=s[0, 0]
)
V_previous[j1, j2] = 1 / (n**2) * e[0, emission_index]
V_prev[j1, j2] = 1 / (n**2) * e[0, emission_index]

# Take a look at the haploid Viterbi implementation in Jerome's code, and
# see if we can pinch some ideas.
@@ -97,7 +97,7 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r):
v = np.zeros((n, n))
for k1 in range(n):
for k2 in range(n):
v[k1, k2] = V_previous[k1, k2]
v[k1, k2] = V_prev[k1, k2]
if (k1 == j1) and (k2 == j2):
v[k1, k2] *= (
(1 - r[l]) ** 2 + 2 * (1 - r[l]) * r_n[l] + r_n[l] ** 2
@@ -109,7 +109,7 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r):
V[j1, j2] = np.amax(v) * e[l, index[j1, j2]]
P[l, j1, j2] = np.argmax(v)
c[l] = np.amax(V)
V_previous = np.copy(V) / c[l]
V_prev = np.copy(V) / c[l]

ll = np.sum(np.log10(c))
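
A note on the rescaling above (not part of the commit): because the Viterbi matrix is renormalised by a per-site constant c[l], the log10 probability of the best path can be recovered as the sum of the log10 rescaling constants (plus, in the variants further down, the log10 of the final rescaled maximum). A toy illustration of that telescoping identity, using made-up per-site maxima:

import numpy as np

# Toy maxima of the unscaled per-site Viterbi matrices.
M = np.array([0.25, 0.02, 0.004, 0.0008])

# Rescale greedily at each site, as the functions here do, recording the constants.
c = np.ones_like(M)
running = M.copy()
for l in range(1, len(M)):
    c[l] = running[l]                   # max of the current, already-rescaled matrix
    running[l:] = running[l:] / c[l]    # rescale this site and everything downstream

# The log10 rescaling constants sum to log10 of the unscaled final maximum.
assert np.isclose(np.sum(np.log10(c)), np.log10(M[-1]))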

@@ -121,7 +121,7 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r):
"""An implementation with reduced memory."""
# Initialise
V = np.zeros((n, n))
V_previous = np.zeros((n, n))
V_prev = np.zeros((n, n))
P = np.zeros((m, n, n), dtype=np.int64)
c = np.ones(m)
r_n = r / n
@@ -132,7 +132,7 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r):
ref_allele=G[0, j1, j2],
query_allele=s[0, 0]
)
V_previous[j1, j2] = 1 / (n**2) * e[0, emission_index]
V_prev[j1, j2] = 1 / (n**2) * e[0, emission_index]

# Diploid Viterbi, with smaller memory footprint, rescaling, and using the structure of the HMM.
for l in range(1, m):
@@ -145,12 +145,12 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r):
+ np.int64(s[0, l] == 1)
)

c[l] = np.amax(V_previous)
argmax = np.argmax(V_previous)
c[l] = np.amax(V_prev)
argmax = np.argmax(V_prev)

V_previous *= 1 / c[l]
V_rowcol_max = core.np_amax(V_previous, 0)
arg_rowcol_max = core.np_argmax(V_previous, 0)
V_prev *= 1 / c[l]
V_rowcol_max = core.np_amax(V_prev, 0)
arg_rowcol_max = core.np_argmax(V_prev, 0)

no_switch = (1 - r[l]) ** 2 + 2 * (r_n[l] * (1 - r[l])) + r_n[l] ** 2
single_switch = r_n[l] * (1 - r[l]) + r_n[l] ** 2
@@ -170,7 +170,7 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r):
else:
template_single_switch = arg_rowcol_max[j2] * n + j2

V[j1, j2] = V_previous[j1, j2] * no_switch # No switch in either
V[j1, j2] = V_prev[j1, j2] * no_switch # No switch in either
P[l, j1, j2] = j1_j2

# Single or double switch?
@@ -188,7 +188,7 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r):

V[j1, j2] *= e[l, index[j1, j2]]
j1_j2 += 1
V_previous = np.copy(V)
V_prev = np.copy(V)

ll = np.sum(np.log10(c)) + np.log10(np.amax(V))

@@ -200,7 +200,7 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r):
"""An implementation with reduced memory and no pointer."""
# Initialise
V = np.zeros((n, n))
V_previous = np.zeros((n, n))
V_prev = np.zeros((n, n))
c = np.ones(m)
r_n = r / n

@@ -221,7 +221,7 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r):
ref_allele=G[0, j1, j2],
query_allele=s[0, 0]
)
V_previous[j1, j2] = 1 / (n**2) * e[0, emission_index]
V_prev[j1, j2] = 1 / (n**2) * e[0, emission_index]

# Diploid Viterbi, with smaller memory footprint, rescaling, and using the structure of the HMM.
for l in range(1, m):
@@ -234,14 +234,14 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r):
+ np.int64(s[0, l] == 1)
)

c[l] = np.amax(V_previous)
argmax = np.argmax(V_previous)
c[l] = np.amax(V_prev)
argmax = np.argmax(V_prev)
V_argmaxes[l - 1] = argmax # added

V_previous *= 1 / c[l]
V_rowcol_max = core.np_amax(V_previous, 0)
V_prev *= 1 / c[l]
V_rowcol_max = core.np_amax(V_prev, 0)
V_rowcol_maxes[l - 1, :] = V_rowcol_max
arg_rowcol_max = core.np_argmax(V_previous, 0)
arg_rowcol_max = core.np_argmax(V_prev, 0)
V_rowcol_argmaxes[l - 1, :] = arg_rowcol_max

no_switch = (1 - r[l]) ** 2 + 2 * (r_n[l] * (1 - r[l])) + r_n[l] ** 2
@@ -253,7 +253,7 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r):
for j1 in range(n):
for j2 in range(n):
V_single_switch = max(V_rowcol_max[j1], V_rowcol_max[j2])
V[j1, j2] = V_previous[j1, j2] * no_switch # No switch in either
V[j1, j2] = V_prev[j1, j2] * no_switch # No switch in either

# Single or double switch?
single_switch_tmp = single_switch * V_single_switch
@@ -270,11 +270,11 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r):

V[j1, j2] *= e[l, index[j1, j2]]
j1_j2 += 1
V_previous = np.copy(V)
V_prev = np.copy(V)

V_argmaxes[m - 1] = np.argmax(V_previous)
V_rowcol_maxes[m - 1, :] = core.np_amax(V_previous, 0)
V_rowcol_argmaxes[m - 1, :] = core.np_argmax(V_previous, 0)
V_argmaxes[m - 1] = np.argmax(V_prev)
V_rowcol_maxes[m - 1, :] = core.np_amax(V_prev, 0)
V_rowcol_argmaxes[m - 1, :] = core.np_argmax(V_prev, 0)
ll = np.sum(np.log10(c)) + np.log10(np.amax(V))

return (
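
A note on the diploid transition weights used above (not part of the commit): each move out of a pair state (j1, j2) is weighted by no_switch, single_switch, or double_switch according to how many of the two haplotypes recombine. The sketch below is a minimal sanity check that these weights, counted with their multiplicities, sum to one over the n**2 destination pairs; it assumes the double-switch weight is r_n**2, which is collapsed out of the hunks shown here, and uses arbitrary toy values of n and r.

import numpy as np

# Toy values; any n >= 2 and 0 < r < 1 behave the same way.
n, r = 4, 0.3
r_n = r / n

# Transition weights as written in forwards_viterbi_dip_low_mem
# (double_switch = r_n**2 is an assumption, not shown in the hunks above).
no_switch = (1 - r) ** 2 + 2 * (r_n * (1 - r)) + r_n**2
single_switch = r_n * (1 - r) + r_n**2
double_switch = r_n**2

# From any source pair (j1, j2): one destination keeps both haplotypes,
# 2 * (n - 1) destinations switch exactly one, and (n - 1)**2 switch both.
total = no_switch + 2 * (n - 1) * single_switch + (n - 1) ** 2 * double_switch
assert np.isclose(total, 1.0)
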
50 changes: 26 additions & 24 deletions lshmm/vit_haploid.py
@@ -29,7 +29,7 @@ def viterbi_naive_init(n, m, H, s, e, r):
@jit.numba_njit
def viterbi_init(n, m, H, s, e, r):
"""Initialise a naive, but more memory efficient, implementation."""
V_previous = np.zeros(n)
V_prev = np.zeros(n)
V = np.zeros(n)
P = np.zeros((m, n), dtype=np.int64)
r_n = r / n
@@ -39,14 +39,14 @@ def viterbi_init(n, m, H, s, e, r):
ref_allele=H[0, i],
query_allele=s[0, 0]
)
V_previous[i] = 1 / n * e[0, emission_idx]
V_prev[i] = 1 / n * e[0, emission_idx]

return V, V_previous, P, r_n
return V, V_prev, P, r_n


@jit.numba_njit
def forwards_viterbi_hap_naive(n, m, H, s, e, r):
"""A naive implementation."""
"""A naive implementation of the forward pass."""
V, P, r_n = viterbi_naive_init(n, m, H, s, e, r)

for j in range(1, m):
@@ -72,7 +72,7 @@ def forwards_viterbi_hap_naive(n, m, H, s, e, r):

@jit.numba_njit
def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r):
"""A naive matrix-based implementation of the forward algorithm using Numpy."""
"""A naive matrix-based implementation of the forward pass using Numpy."""
V, P, r_n = viterbi_naive_init(n, m, H, s, e, r)

for j in range(1, m):
@@ -95,8 +95,8 @@ def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r):

@jit.numba_njit
def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r):
"""A naive implementation with reduced memory."""
V, V_previous, P, r_n = viterbi_init(n, m, H, s, e, r)
"""A naive implementation of the forward pass with reduced memory."""
V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r)

for j in range(1, m):
for i in range(n):
@@ -106,14 +106,14 @@ def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r):
ref_allele=H[j, i],
query_allele=s[0, j]
)
v[k] = V_previous[k] * e[j, emission_idx]
v[k] = V_prev[k] * e[j, emission_idx]
if k == i:
v[k] *= 1 - r[j] + r_n[j]
else:
v[k] *= r_n[j]
P[j, i] = np.argmax(v)
V[i] = v[P[j, i]]
V_previous = np.copy(V)
V_prev = np.copy(V)

ll = np.log10(np.amax(V))

@@ -122,29 +122,28 @@ def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r):

@jit.numba_njit
def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r):
"""A naive implementation with reduced memory and rescaling."""
V, V_previous, P, r_n = viterbi_init(n, m, H, s, e, r)
"""A naive implementation of the forward pass with reduced memory and rescaling."""
V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r)
c = np.ones(m)

for j in range(1, m):
c[j] = np.amax(V_previous)
V_previous *= 1 / c[j]
c[j] = np.amax(V_prev)
V_prev *= 1 / c[j]
for i in range(n):
v = np.zeros(n)
for k in range(n):
emission_idx = core.get_index_in_emission_prob_matrix(
ref_allele=H[j, i],
query_allele=s[0, j]
)
v[k] = V_previous[k] * e[j, emission_idx]
v[k] = V_prev[k] * e[j, emission_idx]
if k == i:
v[k] *= 1 - r[j] + r_n[j]
else:
v[k] *= r_n[j]
P[j, i] = np.argmax(v)
V[i] = v[P[j, i]]

V_previous = np.copy(V)
V_prev = np.copy(V)

ll = np.sum(np.log10(c)) + np.log10(np.amax(V))

@@ -154,16 +153,16 @@ def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r):
@jit.numba_njit
def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r):
"""An implementation with reduced memory that exploits the Markov structure."""
V, V_previous, P, r_n = viterbi_init(n, m, H, s, e, r)
V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r)
c = np.ones(m)

for j in range(1, m):
argmax = np.argmax(V_previous)
c[j] = V_previous[argmax]
V_previous *= 1 / c[j]
argmax = np.argmax(V_prev)
c[j] = V_prev[argmax]
V_prev *= 1 / c[j]
V = np.zeros(n)
for i in range(n):
V[i] = V_previous[i] * (1 - r[j] + r_n[j])
V[i] = V_prev[i] * (1 - r[j] + r_n[j])
P[j, i] = i
if V[i] < r_n[j]:
V[i] = r_n[j]
@@ -173,7 +172,7 @@ def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r):
query_allele=s[0, j]
)
V[i] *= e[j, emission_idx]
V_previous = np.copy(V)
V_prev = np.copy(V)

ll = np.sum(np.log10(c)) + np.log10(np.max(V))

@@ -185,7 +184,10 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r):
"""An implementation with even smaller memory footprint that exploits the Markov structure."""
V = np.zeros(n)
for i in range(n):
emission_idx = core.get_index_in_emission_prob_matrix(H[0, i], s[0, 0])
emission_idx = core.get_index_in_emission_prob_matrix(
ref_allele=H[0, i],
query_allele=s[0, 0]
)
V[i] = 1 / n * e[0, emission_idx]
P = np.zeros((m, n), dtype=np.int64)
r_n = r / n
@@ -214,7 +216,7 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r):

@jit.numba_njit
def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r):
"""LS haploid Viterbi algorithm with even smaller memory footprint and exploits the Markov process structure."""
"""An implementation with even smaller memory footprint and rescaling that exploits the Markov structure."""
V = np.zeros(n)
for i in range(n):
emission_idx = core.get_index_in_emission_prob_matrix(
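
The haploid functions above use only two transition weights per site: 1 - r[j] + r_n[j] for staying on the same reference haplotype and r_n[j] for switching. The following self-contained sketch (not part of the commit) mirrors the structure of forwards_viterbi_hap_naive_low_mem_rescaling on toy data to show the expected shapes of H, s, e, and r; the biallelic 0/1 alleles and the two-column emission matrix indexed by allele match/mismatch are assumptions for illustration, and lshmm itself is not imported.

import numpy as np

# Toy data: n reference haplotypes, m biallelic sites.
n, m = 3, 5
rng = np.random.default_rng(42)
H = rng.integers(0, 2, size=(m, n))   # H[j, i]: allele of haplotype i at site j
s = rng.integers(0, 2, size=(1, m))   # query haplotype
mu = 0.01                             # assumed mismatch (mutation) probability
e = np.empty((m, 2))
e[:, 0] = mu                          # column 0: query and reference alleles differ
e[:, 1] = 1 - mu                      # column 1: query and reference alleles match
r = np.full(m, 0.1)                   # per-site recombination probabilities
r[0] = 0.0
r_n = r / n

# The two transition weights normalise: stay + (n - 1) * switch == 1 at every site.
assert np.allclose((1 - r + r_n) + (n - 1) * r_n, 1.0)

# Naive low-memory Viterbi with rescaling, following the structure of the diff.
# The emission factor is applied after the maximisation because it does not
# depend on the predecessor state.
V_prev = np.array([e[0, int(H[0, i] == s[0, 0])] / n for i in range(n)])
P = np.zeros((m, n), dtype=np.int64)
c = np.ones(m)
for j in range(1, m):
    c[j] = np.amax(V_prev)
    V_prev = V_prev / c[j]
    V = np.zeros(n)
    for i in range(n):
        v = V_prev * np.where(np.arange(n) == i, 1 - r[j] + r_n[j], r_n[j])
        P[j, i] = np.argmax(v)
        V[i] = v[P[j, i]] * e[j, int(H[j, i] == s[0, j])]
    V_prev = V.copy()

ll = np.sum(np.log10(c)) + np.log10(np.amax(V))
print("log10 probability of the best path:", ll)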
