Skip to content

Commit

Permalink
Add an argument to pass a function that defines emission probabilities
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jul 1, 2024
1 parent 5f49c42 commit 34e356d
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 35 deletions.
18 changes: 11 additions & 7 deletions lshmm/fb_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@


@jit.numba_njit
def forwards_ls_hap(n, m, H, s, e, r, norm=True):
def forwards_ls_hap(
n, m, H, s, e, r, norm=True, *, emission_func=core.get_emission_probability_haploid
):
"""
A matrix-based implementation using Numpy.
Expand All @@ -20,7 +22,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
if norm:
c = np.zeros(m)
for i in range(n):
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[0, i],
query_allele=s[0, 0],
site=0,
Expand All @@ -36,7 +38,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
for l in range(1, m):
for i in range(n):
F[l, i] = F[l - 1, i] * (1 - r[l]) + r_n[l]
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[l, i],
query_allele=s[0, l],
site=l,
Expand All @@ -53,7 +55,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
else:
c = np.ones(m)
for i in range(n):
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[0, i],
query_allele=s[0, 0],
site=0,
Expand All @@ -65,7 +67,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
for l in range(1, m):
for i in range(n):
F[l, i] = F[l - 1, i] * (1 - r[l]) + np.sum(F[l - 1, :]) * r_n[l]
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[l, i],
query_allele=s[0, l],
site=l,
Expand All @@ -79,7 +81,9 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):


@jit.numba_njit
def backwards_ls_hap(n, m, H, s, e, c, r):
def backwards_ls_hap(
n, m, H, s, e, c, r, *, emission_func=core.get_emission_probability_haploid
):
"""
A matrix-based implementation using Numpy.
Expand All @@ -96,7 +100,7 @@ def backwards_ls_hap(n, m, H, s, e, c, r):
tmp_B = np.zeros(n)
tmp_B_sum = 0
for i in range(n):
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[l + 1, i],
query_allele=s[0, l + 1],
site=l + 1,
Expand Down
76 changes: 48 additions & 28 deletions lshmm/vit_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@


@jit.numba_njit
def viterbi_naive_init(n, m, H, s, e, r):
def viterbi_naive_init(
n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid
):
"""Initialise a naive implementation."""
V = np.zeros((m, n))
P = np.zeros((m, n), dtype=np.int64)
num_copiable_entries = core.get_num_copiable_entries(H)
r_n = r / num_copiable_entries

for i in range(n):
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[0, i],
query_allele=s[0, 0],
site=0,
Expand All @@ -27,7 +29,9 @@ def viterbi_naive_init(n, m, H, s, e, r):


@jit.numba_njit
def viterbi_init(n, m, H, s, e, r):
def viterbi_init(
n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid
):
"""Initialise a naive, but more memory efficient, implementation."""
V_prev = np.zeros(n)
V = np.zeros(n)
Expand All @@ -36,7 +40,7 @@ def viterbi_init(n, m, H, s, e, r):
r_n = r / num_copiable_entries

for i in range(n):
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[0, i],
query_allele=s[0, 0],
site=0,
Expand All @@ -48,15 +52,17 @@ def viterbi_init(n, m, H, s, e, r):


@jit.numba_njit
def forwards_viterbi_hap_naive(n, m, H, s, e, r):
def forwards_viterbi_hap_naive(
n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid
):
"""A naive implementation of the forward pass."""
V, P, r_n = viterbi_naive_init(n, m, H, s, e, r)
V, P, r_n = viterbi_naive_init(n, m, H, s, e, r, emission_func)

for j in range(1, m):
for i in range(n):
v = np.zeros(n)
for k in range(n):
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[j, i],
query_allele=s[0, j],
site=j,
Expand All @@ -76,16 +82,18 @@ def forwards_viterbi_hap_naive(n, m, H, s, e, r):


@jit.numba_njit
def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r):
def forwards_viterbi_hap_naive_vec(
n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid
):
"""A naive matrix-based implementation of the forward pass."""
V, P, r_n = viterbi_naive_init(n, m, H, s, e, r)
V, P, r_n = viterbi_naive_init(n, m, H, s, e, r, emission_func)

for j in range(1, m):
v_tmp = V[j - 1, :] * r_n[j]
for i in range(n):
v = np.copy(v_tmp)
v[i] += V[j - 1, i] * (1 - r[j])
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[j, i],
query_allele=s[0, j],
site=j,
Expand All @@ -101,15 +109,17 @@ def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r):


@jit.numba_njit
def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r):
def forwards_viterbi_hap_naive_low_mem(
n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid
):
"""A naive implementation of the forward pass with reduced memory."""
V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r)
V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func)

for j in range(1, m):
for i in range(n):
v = np.zeros(n)
for k in range(n):
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[j, i],
query_allele=s[0, j],
site=j,
Expand All @@ -130,9 +140,11 @@ def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r):


@jit.numba_njit
def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r):
def forwards_viterbi_hap_naive_low_mem_rescaling(
n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid
):
"""A naive implementation of the forward pass with reduced memory and rescaling."""
V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r)
V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func)
c = np.ones(m)

for j in range(1, m):
Expand All @@ -141,7 +153,7 @@ def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r):
for i in range(n):
v = np.zeros(n)
for k in range(n):
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[j, i],
query_allele=s[0, j],
site=j,
Expand All @@ -162,9 +174,11 @@ def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r):


@jit.numba_njit
def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r):
def forwards_viterbi_hap_low_mem_rescaling(
n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid
):
"""An implementation with reduced memory that exploits the Markov structure."""
V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r)
V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func)
c = np.ones(m)

for j in range(1, m):
Expand All @@ -178,7 +192,7 @@ def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r):
if V[i] < r_n[j]:
V[i] = r_n[j]
P[j, i] = argmax
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[j, i],
query_allele=s[0, j],
site=j,
Expand All @@ -193,7 +207,9 @@ def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r):


@jit.numba_njit
def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r):
def forwards_viterbi_hap_lower_mem_rescaling(
n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid
):
"""
An implementation with even smaller memory footprint
that exploits the Markov structure.
Expand All @@ -202,7 +218,7 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r):
"""
V = np.zeros(n)
for i in range(n):
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[0, i],
query_allele=s[0, 0],
site=0,
Expand All @@ -224,7 +240,7 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r):
if V[i] < r_n[j]:
V[i] = r_n[j]
P[j, i] = argmax
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[j, i],
query_allele=s[0, j],
site=j,
Expand All @@ -238,14 +254,16 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r):


@jit.numba_njit
def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r):
def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(
n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid
):
"""
An implementation with even smaller memory footprint and rescaling
that exploits the Markov structure.
"""
V = np.zeros(n)
for i in range(n):
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[0, i],
query_allele=s[0, 0],
site=0,
Expand Down Expand Up @@ -273,7 +291,7 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r):
recombs[j] = np.append(
recombs[j], i
) # We add template i as a potential template to recombine to at site j.
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[j, i],
query_allele=s[0, j],
site=j,
Expand Down Expand Up @@ -320,13 +338,15 @@ def backwards_viterbi_hap_no_pointer(m, V_argmaxes, recombs):


@jit.numba_njit
def path_ll_hap(n, m, H, path, s, e, r):
def path_ll_hap(
n, m, H, path, s, e, r, *, emission_func=core.get_emission_probability_haploid
):
"""
Evaluate the log-likelihood of a path through a reference panel resulting in a query.
This is exposed via the API.
"""
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[0, path[0]],
query_allele=s[0, 0],
site=0,
Expand All @@ -338,7 +358,7 @@ def path_ll_hap(n, m, H, path, s, e, r):
r_n = r / num_copiable_entries

for l in range(1, m):
emission_prob = core.get_emission_probability_haploid(
emission_prob = emission_func(
ref_allele=H[l, path[l]],
query_allele=s[0, l],
site=l,
Expand Down

0 comments on commit 34e356d

Please sign in to comment.