Skip to content

Commit

Permalink
WIP diploid
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Apr 19, 2024
1 parent a0472e4 commit 9dac3c0
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 58 deletions.
21 changes: 20 additions & 1 deletion lshmm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,23 @@ def np_argmax(array, axis):

@jit.numba_njit
def get_index_in_emission_prob_matrix(ref_allele, query_allele):
    """Return the index into the emission probability matrix.

    The index is 1 when the query allele equals the reference allele or
    when the query allele is MISSING, and 0 otherwise.

    :param ref_allele: Allele in the reference at the site.
    :param query_allele: Allele in the query at the site (may be MISSING).
    :return: 1 for a match or a missing query allele, else 0.
    """
    is_allele_match = np.equal(ref_allele, query_allele)
    # A missing query allele shares the index used for a match.
    is_query_missing = query_allele == MISSING
    if is_allele_match or is_query_missing:
        return 1
    return 0


@jit.numba_njit
def get_index_in_emission_prob_matrix_diploid(ref_allele, query_allele):
    """Return the index into the diploid emission probability matrix.

    A MISSING query allele maps to MISSING_INDEX. Otherwise the index is the
    sum of three weighted flags:

    * 4 if the reference and query alleles are equal,
    * 2 if the reference allele equals 1,
    * 1 if the query allele equals 1.

    :param ref_allele: Reference value at the site.
    :param query_allele: Query value at the site (may be MISSING).
    :return: MISSING_INDEX, or the weighted flag sum described above.
    """
    # Guard clause: missing query data has its own dedicated index.
    if query_allele == MISSING:
        return MISSING_INDEX

    match_flag = ref_allele == query_allele
    ref_one_flag = ref_allele == 1
    query_one_flag = query_allele == 1
    return 4 * match_flag + 2 * ref_one_flag + query_one_flag
93 changes: 36 additions & 57 deletions lshmm/vit_diploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,11 @@ def forwards_viterbi_dip_naive(n, m, G, s, e, r):

for j1 in range(n):
for j2 in range(n):
if s[0, 0] == core.MISSING:
index_tmp = core.MISSING_INDEX
else:
index_tmp = (
4 * np.int64(np.equal(G[0, j1, j2], s[0, 0]))
+ 2 * np.int64((G[0, j1, j2] == 1))
+ np.int64(s[0, 0] == 1)
)
V[0, j1, j2] = 1 / (n**2) * e[0, index_tmp]
emission_index = core.get_index_in_emission_prob_matrix_diploid(
ref_allele=G[0, j1, j2],
query_allele=s[0, 0]
)
V[0, j1, j2] = 1 / (n**2) * e[0, emission_index]

for l in range(1, m):
if s[0, l] == core.MISSING:
Expand Down Expand Up @@ -77,15 +73,11 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r):

for j1 in range(n):
for j2 in range(n):
if s[0, 0] == core.MISSING:
index_tmp = core.MISSING_INDEX
else:
index_tmp = (
4 * np.int64(np.equal(G[0, j1, j2], s[0, 0]))
+ 2 * np.int64((G[0, j1, j2] == 1))
+ np.int64(s[0, 0] == 1)
)
V_previous[j1, j2] = 1 / (n**2) * e[0, index_tmp]
emission_index = core.get_index_in_emission_prob_matrix_diploid(
ref_allele=G[0, j1, j2],
query_allele=s[0, 0]
)
V_previous[j1, j2] = 1 / (n**2) * e[0, emission_index]

# Take a look at the haploid Viterbi implementation in Jerome's code, and
# see if we can pinch some ideas.
Expand Down Expand Up @@ -136,15 +128,11 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r):

for j1 in range(n):
for j2 in range(n):
if s[0, 0] == core.MISSING:
index_tmp = core.MISSING_INDEX
else:
index_tmp = (
4 * np.int64(np.equal(G[0, j1, j2], s[0, 0]))
+ 2 * np.int64((G[0, j1, j2] == 1))
+ np.int64(s[0, 0] == 1)
)
V_previous[j1, j2] = 1 / (n**2) * e[0, index_tmp]
emission_index = core.get_index_in_emission_prob_matrix_diploid(
ref_allele=G[0, j1, j2],
query_allele=s[0, 0]
)
V_previous[j1, j2] = 1 / (n**2) * e[0, emission_index]

# Diploid Viterbi, with smaller memory footprint, rescaling, and using the structure of the HMM.
for l in range(1, m):
Expand Down Expand Up @@ -209,7 +197,7 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r):

@jit.numba_njit
def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r):
"""LS diploid Viterbi algorithm, with reduced memory."""
"""An implementation with reduced memory and no pointer."""
# Initialise
V = np.zeros((n, n))
V_previous = np.zeros((n, n))
Expand All @@ -229,15 +217,11 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r):

for j1 in range(n):
for j2 in range(n):
if s[0, 0] == core.MISSING:
index_tmp = core.MISSING_INDEX
else:
index_tmp = (
4 * np.int64(np.equal(G[0, j1, j2], s[0, 0]))
+ 2 * np.int64((G[0, j1, j2] == 1))
+ np.int64(s[0, 0] == 1)
)
V_previous[j1, j2] = 1 / (n**2) * e[0, index_tmp]
emission_index = core.get_index_in_emission_prob_matrix_diploid(
ref_allele=G[0, j1, j2],
query_allele=s[0, 0]
)
V_previous[j1, j2] = 1 / (n**2) * e[0, emission_index]

# Diploid Viterbi, with smaller memory footprint, rescaling, and using the structure of the HMM.
for l in range(1, m):
Expand Down Expand Up @@ -315,15 +299,11 @@ def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r):

for j1 in range(n):
for j2 in range(n):
if s[0, 0] == core.MISSING:
index_tmp = core.MISSING_INDEX
else:
index_tmp = (
4 * np.int64(np.equal(G[0, j1, j2], s[0, 0]))
+ 2 * np.int64((G[0, j1, j2] == 1))
+ np.int64(s[0, 0] == 1)
)
V[0, j1, j2] = 1 / (n**2) * e[0, index_tmp]
emission_index = core.get_index_in_emission_prob_matrix_diploid(
ref_allele=G[0, j1, j2],
query_allele=s[0, 0]
)
V[0, j1, j2] = 1 / (n**2) * e[0, emission_index]

# Jumped the gun - vectorising.
for l in range(1, m):
Expand Down Expand Up @@ -365,6 +345,7 @@ def forwards_viterbi_dip_naive_full_vec(n, m, G, s, e, r):
V = np.zeros((m, n, n))
P = np.zeros((m, n, n), dtype=np.int64)
c = np.ones(m)

if s[0, 0] == core.MISSING:
index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64)
else:
Expand Down Expand Up @@ -406,7 +387,8 @@ def backwards_viterbi_dip(m, V_last, P):
"""Run a backwards pass to determine the most likely path."""
assert V_last.ndim == 2
assert V_last.shape[0] == V_last.shape[1]
# Initialisation

# Initialise
path = np.zeros(m, dtype=np.int64)
path[m - 1] = np.argmax(V_last)

Expand Down Expand Up @@ -438,7 +420,8 @@ def backwards_viterbi_dip_no_pointer(
"""Run a backwards pass to determine the most likely path."""
assert V_last.ndim == 2
assert V_last.shape[0] == V_last.shape[1]
# Initialisation

# Initialise
path = np.zeros(m, dtype=np.int64)
path[m - 1] = np.argmax(V_last)
n = V_last.shape[0]
Expand Down Expand Up @@ -469,15 +452,11 @@ def get_phased_path(n, path):
@jit.numba_njit
def path_ll_dip(n, m, G, phased_path, s, e, r):
"""Evaluate log-likelihood path through a reference panel which results in sequence s."""
if s[0, 0] == core.MISSING:
index = core.MISSING_INDEX
else:
index = (
4 * np.int64(np.equal(G[0, phased_path[0][0], phased_path[1][0]], s[0, 0]))
+ 2 * np.int64(G[0, phased_path[0][0], phased_path[1][0]] == 1)
+ np.int64(s[0, 0] == 1)
)
log_prob_path = np.log10(1 / (n**2) * e[0, index])
emission_index = core.get_index_in_emission_prob_matrix_diploid(
ref_allele=G[0, phased_path[0][0], phased_path[1][0]],
query_allele=s[0, 0]
)
log_prob_path = np.log10(1 / (n**2) * e[0, emission_index])
old_phase = np.array([phased_path[0][0], phased_path[1][0]])
r_n = r / n

Expand Down

0 comments on commit 9dac3c0

Please sign in to comment.