Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Apr 19, 2024
1 parent ed830e4 commit 38bee7d
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 112 deletions.
14 changes: 7 additions & 7 deletions lshmm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# https://github.com/numba/numba/issues/1269
@jit.numba_njit
def np_apply_along_axis(func1d, axis, arr):
"""Create numpy-like functions for max, sum etc."""
"""Create Numpy-like functions for max, sum, etc."""
assert arr.ndim == 2
assert axis in [0, 1]
if axis == 0:
Expand All @@ -35,27 +35,27 @@ def np_apply_along_axis(func1d, axis, arr):

@jit.numba_njit
def np_amax(array, axis):
    """Numba implementation of Numpy-vectorised max.

    Wraps `np.amax` via `np_apply_along_axis` because Numba's njit mode
    does not support the `axis` argument of `np.amax` directly.
    """
    # NOTE: the scraped diff showed both the pre- and post-commit docstring
    # lines; only the post-commit docstring is kept here.
    return np_apply_along_axis(np.amax, axis, array)


@jit.numba_njit
def np_sum(array, axis):
    """Numba implementation of Numpy-vectorised sum.

    Wraps `np.sum` via `np_apply_along_axis` because Numba's njit mode
    does not support the `axis` argument of `np.sum` directly.
    """
    # NOTE: the scraped diff showed both the pre- and post-commit docstring
    # lines; only the post-commit docstring is kept here.
    return np_apply_along_axis(np.sum, axis, array)


@jit.numba_njit
def np_argmax(array, axis):
    """Numba implementation of Numpy-vectorised argmax.

    Wraps `np.argmax` via `np_apply_along_axis` because Numba's njit mode
    does not support the `axis` argument of `np.argmax` directly.
    """
    # NOTE: the scraped diff showed both the pre- and post-commit docstring
    # lines; only the post-commit docstring is kept here.
    return np_apply_along_axis(np.argmax, axis, array)


""" Functions used across different implementations of the LS HMM. """


@jit.numba_njit
def get_index_in_emission_prob_matrix(ref_allele, query_allele):
def get_index_in_emission_matrix(ref_allele, query_allele):
is_allele_match = np.equal(ref_allele, query_allele)
is_query_missing = query_allele == MISSING
if is_allele_match or is_query_missing:
Expand All @@ -64,7 +64,7 @@ def get_index_in_emission_prob_matrix(ref_allele, query_allele):


@jit.numba_njit
def get_index_in_emission_prob_matrix_diploid(ref_allele, query_allele):
def get_index_in_emission_matrix_diploid(ref_allele, query_allele):
if query_allele == MISSING:
return MISSING_INDEX
else:
Expand All @@ -75,7 +75,7 @@ def get_index_in_emission_prob_matrix_diploid(ref_allele, query_allele):


@jit.numba_njit
def get_index_in_emission_prob_matrix_diploid_G(ref_G, query_allele, n):
def get_index_in_emission_matrix_diploid_G(ref_G, query_allele, n):
if query_allele == MISSING:
return MISSING_INDEX * np.ones((n, n), dtype=np.int64)
else:
Expand Down
101 changes: 37 additions & 64 deletions lshmm/fb_diploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,24 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True):
c = np.ones(m)
r_n = r / n

if s[0, 0] == core.MISSING:
index = core.MISSING_INDEX * np.ones(
(n, n), dtype=np.int64
) # We could have chosen anything here, this just implies a multiplication by a constant.
else:
index = 4 * np.equal(G[0, :, :], s[0, 0]).astype(np.int64) + 2 * (
G[0, :, :] == 1
).astype(np.int64)
if s[0, 0] == 1:
index += 1

F[0, :, :] *= e[0, index]
emission_index = core.get_index_in_emission_matrix_diploid_G(
ref_G=G[0, :, :],
query_allele=s[0, 0],
n=n
)
F[0, :, :] *= e[0, emission_index]

if norm:
c[0] = np.sum(F[0, :, :])
F[0, :, :] *= 1 / c[0]

# Forwards
for l in range(1, m):
if s[0, l] == core.MISSING:
index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64)
else:
index = 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) + 2 * (
G[l, :, :] == 1
).astype(np.int64)

if s[0, l] == 1:
index += 1
emission_index = core.get_index_in_emission_matrix_diploid_G(
ref_G=G[l, :, :],
query_allele=s[0, l],
n=n
)

# No change in both
F[l, :, :] = (1 - r[l]) ** 2 * F[l - 1, :, :]
Expand All @@ -58,23 +48,19 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True):
F[l, :, :] += ((1 - r[l]) * r_n[l]) * (sum_j + sum_j.T)

# Emission
F[l, :, :] *= e[l, index]
F[l, :, :] *= e[l, emission_index]
c[l] = np.sum(F[l, :, :])
F[l, :, :] *= 1 / c[l]

ll = np.sum(np.log10(c))
else:
# Forwards
for l in range(1, m):
if s[0, l] == core.MISSING:
index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64)
else:
index = 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) + 2 * (
G[l, :, :] == 1
).astype(np.int64)

if s[0, l] == 1:
index += 1
emission_index = core.get_index_in_emission_matrix_diploid_G(
ref_G=G[l, :, :],
query_allele=s[0, l],
n=n
)

# No change in both
F[l, :, :] = (1 - r[l]) ** 2 * F[l - 1, :, :]
Expand All @@ -88,7 +74,7 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True):
F[l, :, :] += ((1 - r[l]) * r_n[l]) * (sum_j + sum_j.T)

# Emission
F[l, :, :] *= e[l, index]
F[l, :, :] *= e[l, emission_index]

ll = np.log10(np.sum(F[l, :, :]))

Expand All @@ -104,30 +90,25 @@ def backwards_ls_dip(n, m, G, s, e, c, r):

# Backwards
for l in range(m - 2, -1, -1):
if s[0, l + 1] == core.MISSING:
index = core.MISSING_INDEX * np.ones(
(n, n), dtype=np.int64
) # We could have chosen anything here, this just implies a multiplication by a constant.
else:
index = (
4 * np.equal(G[l + 1, :, :], s[0, l + 1]).astype(np.int64)
+ 2 * (G[l + 1, :, :] == 1).astype(np.int64)
+ np.int64(s[0, l + 1] == 1)
)
emission_index = core.get_index_in_emission_matrix_diploid_G(
ref_G=G[l + 1, :, :],
query_allele=s[0, l + 1],
n=n
)

# No change in both
B[l, :, :] = r_n[l + 1] ** 2 * np.sum(
e[l + 1, index.reshape((n, n))] * B[l + 1, :, :]
e[l + 1, emission_index.reshape((n, n))] * B[l + 1, :, :]
)

# Both change
B[l, :, :] += (
(1 - r[l + 1]) ** 2 * B[l + 1, :, :] * e[l + 1, index.reshape((n, n))]
(1 - r[l + 1]) ** 2 * B[l + 1, :, :] * e[l + 1, emission_index.reshape((n, n))]
)

# One changes
sum_j = (
core.np_sum(B[l + 1, :, :] * e[l + 1, index], 0).repeat(n).reshape((-1, n))
core.np_sum(B[l + 1, :, :] * e[l + 1, emission_index], 0).repeat(n).reshape((-1, n))
)
B[l, :, :] += ((1 - r[l + 1]) * r_n[l + 1]) * (sum_j + sum_j.T)
B[l, :, :] *= 1 / c[l + 1]
Expand All @@ -145,15 +126,11 @@ def forward_ls_dip_starting_point(n, m, G, s, e, r):
for j1 in range(n):
for j2 in range(n):
F[0, j1, j2] = 1 / (n**2)
if s[0, 0] == core.MISSING:
index_tmp = core.MISSING_INDEX
else:
index_tmp = (
4 * np.int64(np.equal(G[0, j1, j2], s[0, 0]))
+ 2 * np.int64((G[0, j1, j2] == 1))
+ np.int64(s[0, 0] == 1)
)
F[0, j1, j2] *= e[0, index_tmp]
emission_index = core.get_index_in_emission_matrix_diploid(
ref_allele=G[0, j1, j2],
query_allele=s[0, 0]
)
F[0, j1, j2] *= e[0, emission_index]

for l in range(1, m):
F_no_change = np.zeros((n, n))
Expand Down Expand Up @@ -293,21 +270,17 @@ def backward_ls_dip_starting_point(n, m, G, s, e, r):

@jit.numba_njit
def forward_ls_dip_loop(n, m, G, s, e, r, norm=True):
"""LS diploid forwards algoritm without vectorisation."""
"""LS diploid forwards algorithm without vectorisation."""
# Initialise
F = np.zeros((m, n, n))
for j1 in range(n):
for j2 in range(n):
F[0, j1, j2] = 1 / (n**2)
if s[0, 0] == core.MISSING:
index_tmp = core.MISSING_INDEX
else:
index_tmp = (
4 * np.int64(np.equal(G[0, j1, j2], s[0, 0]))
+ 2 * np.int64((G[0, j1, j2] == 1))
+ np.int64(s[0, 0] == 1)
)
F[0, j1, j2] *= e[0, index_tmp]
emission_index = core.get_index_in_emission_matrix_diploid(
ref_allele=G[0, j1, j2],
query_allele=s[0, 0]
)
F[0, j1, j2] *= e[0, emission_index]
r_n = r / n
c = np.ones(m)

Expand Down
20 changes: 10 additions & 10 deletions lshmm/fb_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
if norm:
c = np.zeros(m)
for i in range(n):
emission_idx = core.get_index_in_emission_prob_matrix(
emission_index = core.get_index_in_emission_matrix(
ref_allele=H[0, i], query_allele=s[0, 0]
)
F[0, i] = 1 / n * e[0, emission_idx]
F[0, i] = 1 / n * e[0, emission_index]
c[0] += F[0, i]

for i in range(n):
Expand All @@ -31,10 +31,10 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
for l in range(1, m):
for i in range(n):
F[l, i] = F[l - 1, i] * (1 - r[l]) + r_n[l]
emission_idx = core.get_index_in_emission_prob_matrix(
emission_index = core.get_index_in_emission_matrix(
ref_allele=H[l, i], query_allele=s[0, l]
)
F[l, i] *= e[l, emission_idx]
F[l, i] *= e[l, emission_index]
c[l] += F[l, i]

for i in range(n):
Expand All @@ -46,19 +46,19 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
c = np.ones(m)

for i in range(n):
emission_idx = core.get_index_in_emission_prob_matrix(
emission_index = core.get_index_in_emission_matrix(
ref_allele=H[0, i], query_allele=s[0, 0]
)
F[0, i] = 1 / n * e[0, emission_idx]
F[0, i] = 1 / n * e[0, emission_index]

# Forwards pass
for l in range(1, m):
for i in range(n):
F[l, i] = F[l - 1, i] * (1 - r[l]) + np.sum(F[l - 1, :]) * r_n[l]
emission_idx = core.get_index_in_emission_prob_matrix(
emission_index = core.get_index_in_emission_matrix(
ref_allele=H[l, i], query_allele=s[0, l]
)
F[l, i] *= e[l, emission_idx]
F[l, i] *= e[l, emission_index]

ll = np.log10(np.sum(F[m - 1, :]))

Expand All @@ -78,10 +78,10 @@ def backwards_ls_hap(n, m, H, s, e, c, r):
tmp_B = np.zeros(n)
tmp_B_sum = 0
for i in range(n):
emission_idx = core.get_index_in_emission_prob_matrix(
emission_index = core.get_index_in_emission_matrix(
ref_allele=H[l + 1, i], query_allele=s[0, l + 1]
)
tmp_B[i] = e[l + 1, emission_idx] * B[l + 1, i]
tmp_B[i] = e[l + 1, emission_index] * B[l + 1, i]
tmp_B_sum += tmp_B[i]
for i in range(n):
B[l, i] = r_n[l + 1] * tmp_B_sum
Expand Down
Loading

0 comments on commit 38bee7d

Please sign in to comment.