Skip to content

Commit

Permalink
Refactor
Browse files — browse the repository at this point in the history
  • Loading branch information
szhan committed Apr 20, 2024
1 parent f3a99ef commit b6c9c71
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 268 deletions.
4 changes: 2 additions & 2 deletions lshmm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def np_argmax(array, axis):


@jit.numba_njit
def get_index_in_emission_matrix(ref_allele, query_allele):
is_allele_match = np.equal(ref_allele, query_allele)
def get_index_in_emission_matrix_haploid(ref_allele, query_allele):
is_allele_match = ref_allele == query_allele
is_query_missing = query_allele == MISSING
if is_allele_match or is_query_missing:
return 1
Expand Down
15 changes: 7 additions & 8 deletions lshmm/fb_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@

@jit.numba_njit
def forwards_ls_hap(n, m, H, s, e, r, norm=True):
"""A matrix-based implementation using Numpy vectorisation."""
"""A matrix-based implementation using Numpy."""
F = np.zeros((m, n))
r_n = r / n

if norm:
c = np.zeros(m)
for i in range(n):
emission_index = core.get_index_in_emission_matrix(
emission_index = core.get_index_in_emission_matrix_haploid(
ref_allele=H[0, i], query_allele=s[0, 0]
)
F[0, i] = 1 / n * e[0, emission_index]
Expand All @@ -31,7 +31,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
for l in range(1, m):
for i in range(n):
F[l, i] = F[l - 1, i] * (1 - r[l]) + r_n[l]
emission_index = core.get_index_in_emission_matrix(
emission_index = core.get_index_in_emission_matrix_haploid(
ref_allele=H[l, i], query_allele=s[0, l]
)
F[l, i] *= e[l, emission_index]
Expand All @@ -44,9 +44,8 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):

else:
c = np.ones(m)

for i in range(n):
emission_index = core.get_index_in_emission_matrix(
emission_index = core.get_index_in_emission_matrix_haploid(
ref_allele=H[0, i], query_allele=s[0, 0]
)
F[0, i] = 1 / n * e[0, emission_index]
Expand All @@ -55,7 +54,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
for l in range(1, m):
for i in range(n):
F[l, i] = F[l - 1, i] * (1 - r[l]) + np.sum(F[l - 1, :]) * r_n[l]
emission_index = core.get_index_in_emission_matrix(
emission_index = core.get_index_in_emission_matrix_haploid(
ref_allele=H[l, i], query_allele=s[0, l]
)
F[l, i] *= e[l, emission_index]
Expand All @@ -67,7 +66,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):

@jit.numba_njit
def backwards_ls_hap(n, m, H, s, e, c, r):
"""A matrix-based implementation using Numpy vectorisation."""
"""A matrix-based implementation using Numpy."""
B = np.zeros((m, n))
for i in range(n):
B[m - 1, i] = 1
Expand All @@ -78,7 +77,7 @@ def backwards_ls_hap(n, m, H, s, e, c, r):
tmp_B = np.zeros(n)
tmp_B_sum = 0
for i in range(n):
emission_index = core.get_index_in_emission_matrix(
emission_index = core.get_index_in_emission_matrix_haploid(
ref_allele=H[l + 1, i], query_allele=s[0, l + 1]
)
tmp_B[i] = e[l + 1, emission_index] * B[l + 1, i]
Expand Down
4 changes: 2 additions & 2 deletions tests/lsbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,10 @@ def get_multiallelic_n16(self, seed=42):
return ts

# Prepare a larger example dataset.
def get_larger(self, num_samples, seq_length, mean_r, mean_mu, seed=42):
def get_larger(self, num_samples, length, mean_r, mean_mu, seed=42):
ts = msprime.simulate(
num_samples + 1,
length=seq_length,
length=length,
mutation_rate=mean_mu,
recombination_rate=mean_r,
random_seed=seed,
Expand Down
Loading

0 comments on commit b6c9c71

Please sign in to comment.