Skip to content

Commit

Permalink
Rename variables in functions assigning emission probabilities for diploid case
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jun 10, 2024
1 parent 8f5cf2d commit 5717033
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 83 deletions.
59 changes: 37 additions & 22 deletions lshmm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,25 +173,32 @@ def get_index_in_emission_matrix_haploid(ref_allele, query_allele):


@jit.numba_njit
def get_index_in_emission_matrix_diploid(ref_genotype, query_genotype):
    """
    Compare the implied unphased genotypes (allele dosages) of
    the reference and query to get the index of the entry
    in the emission probability matrix, and return the index.

    :param ref_genotype: Genotype of the reference at a site.
    :param query_genotype: Genotype of the query at the same site.
    :return: MISSING_INDEX if the query genotype equals MISSING; otherwise
        4 * (ref == query) + 2 * (ref == 1) + (query == 1).
    """
    if query_genotype == MISSING:
        return MISSING_INDEX
    else:
        # The three flags are combined as bits of the column index:
        # bit 2 = genotypes match, bit 1 = reference genotype equals 1,
        # bit 0 = query genotype equals 1 (a genotype of 1 is what this
        # code names "het").
        is_match = ref_genotype == query_genotype
        is_ref_het = ref_genotype == 1
        is_query_het = query_genotype == 1
        return 4 * is_match + 2 * is_ref_het + is_query_het


@jit.numba_njit
def get_index_in_emission_matrix_diploid_genotypes(
    ref_genotypes, query_genotype, num_ref_haps
):
    """
    Array version of the diploid emission-index lookup: compare every
    reference genotype in `ref_genotypes` with a single query genotype
    and return the indices into the emission probability matrix.

    :param ref_genotypes: Array of reference genotypes, compared elementwise
        with the query (presumably of shape (num_ref_haps, num_ref_haps);
        confirm against callers).
    :param query_genotype: Genotype of the query at the site.
    :param num_ref_haps: Side length of the array returned when the query
        genotype is MISSING.
    :return: An int64 array of shape (num_ref_haps, num_ref_haps) filled
        with MISSING_INDEX if the query genotype equals MISSING; otherwise
        the elementwise value
        4 * (ref == query) + 2 * (ref == 1) + (query == 1).
    """
    if query_genotype == MISSING:
        return MISSING_INDEX * np.ones((num_ref_haps, num_ref_haps), dtype=np.int64)
    else:
        # Elementwise flags on the reference array; the query flag is a
        # scalar that is broadcast in the sum below.
        is_match = ref_genotypes == query_genotype
        is_ref_het = ref_genotypes == 1
        is_query_het = query_genotype == 1
        return 4 * is_match + 2 * is_ref_het + is_query_het


def get_emission_matrix_haploid(mu, num_sites, num_alleles, scale_mutation_rate):
Expand Down Expand Up @@ -264,24 +271,32 @@ def get_emission_probability_haploid(ref_allele, query_allele, site, emission_ma


@jit.numba_njit
def get_emission_probability_diploid(
    ref_genotype, query_genotype, site, emission_matrix
):
    """
    Return the emission probability at a site for a single pair of
    reference and query genotypes.

    :param ref_genotype: Genotype of the reference at the site.
    :param query_genotype: Genotype of the query at the site.
    :param site: Row index into `emission_matrix`.
    :param emission_matrix: Two-dimensional table indexed as
        [site, emission_index].
    :return: 0.0 if the reference genotype equals NONCOPY; otherwise the
        entry of `emission_matrix` selected by the site and by the index
        from `get_index_in_emission_matrix_diploid`.
    """
    if ref_genotype == NONCOPY:
        return 0.0
    else:
        emission_index = get_index_in_emission_matrix_diploid(
            ref_genotype, query_genotype
        )
        return emission_matrix[site, emission_index]


@jit.numba_njit
def get_emission_probability_diploid_genotypes(
    ref_genotypes, query_genotype, site, emission_matrix
):
    """
    Return the emission probabilities at a site for every entry of a
    square array of reference genotypes against a single query genotype.

    :param ref_genotypes: Square two-dimensional array of reference
        genotypes (squareness is asserted below).
    :param query_genotype: Genotype of the query at the site.
    :param site: Row index into `emission_matrix`.
    :param emission_matrix: Two-dimensional table indexed as
        [site, emission_index].
    :return: A float64 array with the same square shape as `ref_genotypes`.
        Entries whose reference genotype equals NONCOPY are 0.0; all others
        are looked up in `emission_matrix`.
    """
    assert ref_genotypes.shape[0] == ref_genotypes.shape[1]
    num_ref_haps = len(ref_genotypes)
    # Zero-initialised, so NONCOPY entries need no explicit assignment.
    emission_probs = np.zeros((num_ref_haps, num_ref_haps), dtype=np.float64)
    for i in range(num_ref_haps):
        for j in range(num_ref_haps):
            if ref_genotypes[i, j] != NONCOPY:
                emission_index = get_index_in_emission_matrix_diploid(
                    ref_genotypes[i, j], query_genotype
                )
                emission_probs[i, j] = emission_matrix[site, emission_index]
    return emission_probs
52 changes: 26 additions & 26 deletions lshmm/fb_diploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True):
c = np.ones(m)
r_n = r / n

emission_probs = core.get_emission_probability_diploid_G(
ref_G=G[0, :, :],
query_allele=s[0, 0],
emission_probs = core.get_emission_probability_diploid_genotypes(
ref_genotypes=G[0, :, :],
query_genotype=s[0, 0],
site=0,
emission_matrix=e,
)
Expand All @@ -31,9 +31,9 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True):

# Forwards
for l in range(1, m):
emission_probs = core.get_emission_probability_diploid_G(
ref_G=G[l, :, :],
query_allele=s[0, l],
emission_probs = core.get_emission_probability_diploid_genotypes(
ref_genotypes=G[l, :, :],
query_genotype=s[0, l],
site=l,
emission_matrix=e,
)
Expand All @@ -57,9 +57,9 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True):
else:
# Forwards
for l in range(1, m):
emission_probs = core.get_emission_probability_diploid_G(
ref_G=G[l, :, :],
query_allele=s[0, l],
emission_probs = core.get_emission_probability_diploid_genotypes(
ref_genotypes=G[l, :, :],
query_genotype=s[0, l],
site=l,
emission_matrix=e,
)
Expand Down Expand Up @@ -92,9 +92,9 @@ def backwards_ls_dip(n, m, G, s, e, c, r):

# Backwards
for l in range(m - 2, -1, -1):
emission_probs = core.get_emission_probability_diploid_G(
ref_G=G[l + 1, :, :],
query_allele=s[0, l + 1],
emission_probs = core.get_emission_probability_diploid_genotypes(
ref_genotypes=G[l + 1, :, :],
query_genotype=s[0, l + 1],
site=l + 1,
emission_matrix=e,
)
Expand Down Expand Up @@ -126,8 +126,8 @@ def forward_ls_dip_starting_point(n, m, G, s, e, r):
for j2 in range(n):
F[0, j1, j2] = 1 / (n**2)
emission_prob = core.get_emission_probability_diploid(
ref_allele=G[0, j1, j2],
query_allele=s[0, 0],
ref_genotype=G[0, j1, j2],
query_genotype=s[0, 0],
site=0,
emission_matrix=e,
)
Expand Down Expand Up @@ -170,8 +170,8 @@ def forward_ls_dip_starting_point(n, m, G, s, e, r):
for j1 in range(n):
for j2 in range(n):
emission_prob = core.get_emission_probability_diploid(
ref_allele=G[l, j1, j2],
query_allele=s[0, l],
ref_genotype=G[l, j1, j2],
query_genotype=s[0, l],
site=l,
emission_matrix=e,
)
Expand Down Expand Up @@ -201,8 +201,8 @@ def backward_ls_dip_starting_point(n, m, G, s, e, r):
for j1 in range(n):
for j2 in range(n):
emission_prob = core.get_emission_probability_diploid(
ref_allele=G[l + 1, j1, j2],
query_allele=s[0, l + 1],
ref_genotype=G[l + 1, j1, j2],
query_genotype=s[0, l + 1],
site=l + 1,
emission_matrix=e,
)
Expand Down Expand Up @@ -258,8 +258,8 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True):
for j2 in range(n):
F[0, j1, j2] = 1 / (n**2)
emission_prob = core.get_emission_probability_diploid(
ref_allele=G[0, j1, j2],
query_allele=s[0, 0],
ref_genotype=G[0, j1, j2],
query_genotype=s[0, 0],
site=0,
emission_matrix=e,
)
Expand Down Expand Up @@ -291,8 +291,8 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True):
for j1 in range(n):
for j2 in range(n):
emission_prob = core.get_emission_probability_diploid(
ref_allele=G[l, j1, j2],
query_allele=s[0, l],
ref_genotype=G[l, j1, j2],
query_genotype=s[0, l],
site=l,
emission_matrix=e,
)
Expand Down Expand Up @@ -328,8 +328,8 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True):
for j1 in range(n):
for j2 in range(n):
emission_prob = core.get_emission_probability_diploid(
ref_allele=G[l, j1, j2],
query_allele=s[0, l],
ref_genotype=G[l, j1, j2],
query_genotype=s[0, l],
site=l,
emission_matrix=e,
)
Expand Down Expand Up @@ -363,8 +363,8 @@ def backward_ls_dip_loop(n, m, G, s, e, c, r):
for j1 in range(n):
for j2 in range(n):
emission_prob = core.get_emission_probability_diploid(
ref_allele=G[l + 1, j1, j2],
query_allele=s[0, l + 1],
ref_genotype=G[l + 1, j1, j2],
query_genotype=s[0, l + 1],
site=l + 1,
emission_matrix=e,
)
Expand Down
70 changes: 35 additions & 35 deletions lshmm/vit_diploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ def forwards_viterbi_dip_naive(n, m, G, s, e, r):
for j1 in range(n):
for j2 in range(n):
emission_prob = core.get_emission_probability_diploid(
ref_allele=G[0, j1, j2],
query_allele=s[0, 0],
ref_genotype=G[0, j1, j2],
query_genotype=s[0, 0],
site=0,
emission_matrix=e,
)
V[0, j1, j2] = 1 / (n**2) * emission_prob

for l in range(1, m):
emission_probs = core.get_emission_probability_diploid_G(
ref_G=G[l, :, :],
query_allele=s[0, l],
emission_probs = core.get_emission_probability_diploid_genotypes(
ref_genotypes=G[l, :, :],
query_genotype=s[0, l],
site=l,
emission_matrix=e,
)
Expand Down Expand Up @@ -73,8 +73,8 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r):
for j1 in range(n):
for j2 in range(n):
emission_prob = core.get_emission_probability_diploid(
ref_allele=G[0, j1, j2],
query_allele=s[0, 0],
ref_genotype=G[0, j1, j2],
query_genotype=s[0, 0],
site=0,
emission_matrix=e,
)
Expand All @@ -84,9 +84,9 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r):
# see if we can pinch some ideas.
# Diploid Viterbi, with smaller memory footprint.
for l in range(1, m):
emission_probs = core.get_emission_probability_diploid_G(
ref_G=G[l, :, :],
query_allele=s[0, l],
emission_probs = core.get_emission_probability_diploid_genotypes(
ref_genotypes=G[l, :, :],
query_genotype=s[0, l],
site=l,
emission_matrix=e,
)
Expand Down Expand Up @@ -132,18 +132,18 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r):
for j1 in range(n):
for j2 in range(n):
emission_prob = core.get_emission_probability_diploid(
ref_allele=G[0, j1, j2],
query_allele=s[0, 0],
ref_genotype=G[0, j1, j2],
query_genotype=s[0, 0],
site=0,
emission_matrix=e,
)
V_prev[j1, j2] = 1 / (n**2) * emission_prob

# Diploid Viterbi, with smaller memory footprint, rescaling, and using the structure of the HMM.
for l in range(1, m):
emission_probs = core.get_emission_probability_diploid_G(
ref_G=G[l, :, :],
query_allele=s[0, l],
emission_probs = core.get_emission_probability_diploid_genotypes(
ref_genotypes=G[l, :, :],
query_genotype=s[0, l],
site=l,
emission_matrix=e,
)
Expand Down Expand Up @@ -221,18 +221,18 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r):
for j1 in range(n):
for j2 in range(n):
emission_prob = core.get_emission_probability_diploid(
ref_allele=G[0, j1, j2],
query_allele=s[0, 0],
ref_genotype=G[0, j1, j2],
query_genotype=s[0, 0],
site=0,
emission_matrix=e,
)
V_prev[j1, j2] = 1 / (n**2) * emission_prob

# Diploid Viterbi, with smaller memory footprint, rescaling, and using the structure of the HMM.
for l in range(1, m):
emission_probs = core.get_emission_probability_diploid_G(
ref_G=G[l, :, :],
query_allele=s[0, l],
emission_probs = core.get_emission_probability_diploid_genotypes(
ref_genotypes=G[l, :, :],
query_genotype=s[0, l],
site=l,
emission_matrix=e,
)
Expand Down Expand Up @@ -303,18 +303,18 @@ def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r):
for j1 in range(n):
for j2 in range(n):
emission_prob = core.get_emission_probability_diploid(
ref_allele=G[0, j1, j2],
query_allele=s[0, 0],
ref_genotype=G[0, j1, j2],
query_genotype=s[0, 0],
site=0,
emission_matrix=e,
)
V[0, j1, j2] = 1 / (n**2) * emission_prob

# Jumped the gun - vectorising.
for l in range(1, m):
emission_probs = core.get_emission_probability_diploid_G(
ref_G=G[l, :, :],
query_allele=s[0, l],
emission_probs = core.get_emission_probability_diploid_genotypes(
ref_genotypes=G[l, :, :],
query_genotype=s[0, l],
site=l,
emission_matrix=e,
)
Expand Down Expand Up @@ -349,19 +349,19 @@ def forwards_viterbi_dip_naive_full_vec(n, m, G, s, e, r):
P = np.zeros((m, n, n), dtype=np.int64)
c = np.ones(m)

emission_probs = core.get_emission_probability_diploid_G(
ref_G=G[0, :, :],
query_allele=s[0, 0],
emission_probs = core.get_emission_probability_diploid_genotypes(
ref_genotypes=G[0, :, :],
query_genotype=s[0, 0],
site=l,
emission_matrix=e,
)
V[0, :, :] = 1 / (n**2) * emission_probs
r_n = r / n

for l in range(1, m):
emission_probs = core.get_emission_probability_diploid_G(
ref_G=G[l, :, :],
query_allele=s[0, l],
emission_probs = core.get_emission_probability_diploid_genotypes(
ref_genotypes=G[l, :, :],
query_genotype=s[0, l],
site=l,
emission_matrix=e,
)
Expand Down Expand Up @@ -460,8 +460,8 @@ def path_ll_dip(n, m, G, phased_path, s, e, r):
This is exposed via the API.
"""
emission_prob = core.get_emission_probability_diploid(
ref_allele=G[0, phased_path[0][0], phased_path[1][0]],
query_allele=s[0, 0],
ref_genotype=G[0, phased_path[0][0], phased_path[1][0]],
query_genotype=s[0, 0],
site=0,
emission_matrix=e,
)
Expand All @@ -472,8 +472,8 @@ def path_ll_dip(n, m, G, phased_path, s, e, r):

for l in range(1, m):
emission_prob = core.get_emission_probability_diploid(
ref_allele=G[l, phased_path[0][l], phased_path[1][l]],
query_allele=s[0, l],
ref_genotype=G[l, phased_path[0][l], phased_path[1][l]],
query_genotype=s[0, l],
site=l,
emission_matrix=e,
)
Expand Down

0 comments on commit 5717033

Please sign in to comment.