diff --git a/lshmm/core.py b/lshmm/core.py index ae0fbfc..648fed5 100644 --- a/lshmm/core.py +++ b/lshmm/core.py @@ -173,25 +173,32 @@ def get_index_in_emission_matrix_haploid(ref_allele, query_allele): @jit.numba_njit -def get_index_in_emission_matrix_diploid(ref_allele, query_allele): - if query_allele == MISSING: +def get_index_in_emission_matrix_diploid(ref_genotype, query_genotype): + """ + Compare the implied unphased genotypes (allele dosages) of + the reference and query to get the index of the entry + in the emission probability matrix, and return the index. + """ + if query_genotype == MISSING: return MISSING_INDEX else: - is_match = ref_allele == query_allele - is_ref_one = ref_allele == 1 - is_query_one = query_allele == 1 - return 4 * is_match + 2 * is_ref_one + is_query_one + is_match = ref_genotype == query_genotype + is_ref_het = ref_genotype == 1 + is_query_het = query_genotype == 1 + return 4 * is_match + 2 * is_ref_het + is_query_het @jit.numba_njit -def get_index_in_emission_matrix_diploid_G(ref_G, query_allele, n): - if query_allele == MISSING: - return MISSING_INDEX * np.ones((n, n), dtype=np.int64) +def get_index_in_emission_matrix_diploid_genotypes( + ref_genotypes, query_genotype, num_ref_haps +): + if query_genotype == MISSING: + return MISSING_INDEX * np.ones((num_ref_haps, num_ref_haps), dtype=np.int64) else: - is_match = ref_G == query_allele - is_ref_one = ref_G == 1 - is_query_one = query_allele == 1 - return 4 * is_match + 2 * is_ref_one + is_query_one + is_match = ref_genotypes == query_genotype + is_ref_het = ref_genotypes == 1 + is_query_het = query_genotype == 1 + return 4 * is_match + 2 * is_ref_het + is_query_het def get_emission_matrix_haploid(mu, num_sites, num_alleles, scale_mutation_rate): @@ -264,24 +271,32 @@ def get_emission_probability_haploid(ref_allele, query_allele, site, emission_ma @jit.numba_njit -def get_emission_probability_diploid(ref_allele, query_allele, site, emission_matrix): - if ref_allele == NONCOPY: 
+def get_emission_probability_diploid( + ref_genotype, query_genotype, site, emission_matrix +): + if ref_genotype == NONCOPY: return 0.0 else: - emission_index = get_index_in_emission_matrix_diploid(ref_allele, query_allele) + emission_index = get_index_in_emission_matrix_diploid( + ref_genotype, query_genotype + ) return emission_matrix[site, emission_index] @jit.numba_njit -def get_emission_probability_diploid_G(ref_G, query_allele, site, emission_matrix): - emission_probs = np.zeros(ref_G.shape, dtype=np.float64) - for i in range(len(ref_G)): - for j in range(len(ref_G)): - if ref_G[i, j] == NONCOPY: +def get_emission_probability_diploid_genotypes( + ref_genotypes, query_genotype, site, emission_matrix +): + assert ref_genotypes.shape[0] == ref_genotypes.shape[1] + num_ref_haps = len(ref_genotypes) + emission_probs = np.zeros((num_ref_haps, num_ref_haps), dtype=np.float64) + for i in range(num_ref_haps): + for j in range(num_ref_haps): + if ref_genotypes[i, j] == NONCOPY: emission_probs[i, j] = 0.0 else: emission_index = get_index_in_emission_matrix_diploid( - ref_G[i, j], query_allele + ref_genotypes[i, j], query_genotype ) emission_probs[i, j] = emission_matrix[site, emission_index] return emission_probs diff --git a/lshmm/fb_diploid.py b/lshmm/fb_diploid.py index 9de6564..42a831b 100644 --- a/lshmm/fb_diploid.py +++ b/lshmm/fb_diploid.py @@ -17,9 +17,9 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True): c = np.ones(m) r_n = r / n - emission_probs = core.get_emission_probability_diploid_G( - ref_G=G[0, :, :], - query_allele=s[0, 0], + emission_probs = core.get_emission_probability_diploid_genotypes( + ref_genotypes=G[0, :, :], + query_genotype=s[0, 0], site=0, emission_matrix=e, ) @@ -31,9 +31,9 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True): # Forwards for l in range(1, m): - emission_probs = core.get_emission_probability_diploid_G( - ref_G=G[l, :, :], - query_allele=s[0, l], + emission_probs = core.get_emission_probability_diploid_genotypes( + 
ref_genotypes=G[l, :, :], + query_genotype=s[0, l], site=l, emission_matrix=e, ) @@ -57,9 +57,9 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True): else: # Forwards for l in range(1, m): - emission_probs = core.get_emission_probability_diploid_G( - ref_G=G[l, :, :], - query_allele=s[0, l], + emission_probs = core.get_emission_probability_diploid_genotypes( + ref_genotypes=G[l, :, :], + query_genotype=s[0, l], site=l, emission_matrix=e, ) @@ -92,9 +92,9 @@ def backwards_ls_dip(n, m, G, s, e, c, r): # Backwards for l in range(m - 2, -1, -1): - emission_probs = core.get_emission_probability_diploid_G( - ref_G=G[l + 1, :, :], - query_allele=s[0, l + 1], + emission_probs = core.get_emission_probability_diploid_genotypes( + ref_genotypes=G[l + 1, :, :], + query_genotype=s[0, l + 1], site=l + 1, emission_matrix=e, ) @@ -126,8 +126,8 @@ def forward_ls_dip_starting_point(n, m, G, s, e, r): for j2 in range(n): F[0, j1, j2] = 1 / (n**2) emission_prob = core.get_emission_probability_diploid( - ref_allele=G[0, j1, j2], - query_allele=s[0, 0], + ref_genotype=G[0, j1, j2], + query_genotype=s[0, 0], site=0, emission_matrix=e, ) @@ -170,8 +170,8 @@ def forward_ls_dip_starting_point(n, m, G, s, e, r): for j1 in range(n): for j2 in range(n): emission_prob = core.get_emission_probability_diploid( - ref_allele=G[l, j1, j2], - query_allele=s[0, l], + ref_genotype=G[l, j1, j2], + query_genotype=s[0, l], site=l, emission_matrix=e, ) @@ -201,8 +201,8 @@ def backward_ls_dip_starting_point(n, m, G, s, e, r): for j1 in range(n): for j2 in range(n): emission_prob = core.get_emission_probability_diploid( - ref_allele=G[l + 1, j1, j2], - query_allele=s[0, l + 1], + ref_genotype=G[l + 1, j1, j2], + query_genotype=s[0, l + 1], site=l + 1, emission_matrix=e, ) @@ -258,8 +258,8 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): for j2 in range(n): F[0, j1, j2] = 1 / (n**2) emission_prob = core.get_emission_probability_diploid( - ref_allele=G[0, j1, j2], - query_allele=s[0, 0], + 
ref_genotype=G[0, j1, j2], + query_genotype=s[0, 0], site=0, emission_matrix=e, ) @@ -291,8 +291,8 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): for j1 in range(n): for j2 in range(n): emission_prob = core.get_emission_probability_diploid( - ref_allele=G[l, j1, j2], - query_allele=s[0, l], + ref_genotype=G[l, j1, j2], + query_genotype=s[0, l], site=l, emission_matrix=e, ) @@ -328,8 +328,8 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): for j1 in range(n): for j2 in range(n): emission_prob = core.get_emission_probability_diploid( - ref_allele=G[l, j1, j2], - query_allele=s[0, l], + ref_genotype=G[l, j1, j2], + query_genotype=s[0, l], site=l, emission_matrix=e, ) @@ -363,8 +363,8 @@ def backward_ls_dip_loop(n, m, G, s, e, c, r): for j1 in range(n): for j2 in range(n): emission_prob = core.get_emission_probability_diploid( - ref_allele=G[l + 1, j1, j2], - query_allele=s[0, l + 1], + ref_genotype=G[l + 1, j1, j2], + query_genotype=s[0, l + 1], site=l + 1, emission_matrix=e, ) diff --git a/lshmm/vit_diploid.py b/lshmm/vit_diploid.py index 031674a..70c1001 100644 --- a/lshmm/vit_diploid.py +++ b/lshmm/vit_diploid.py @@ -21,17 +21,17 @@ def forwards_viterbi_dip_naive(n, m, G, s, e, r): for j1 in range(n): for j2 in range(n): emission_prob = core.get_emission_probability_diploid( - ref_allele=G[0, j1, j2], - query_allele=s[0, 0], + ref_genotype=G[0, j1, j2], + query_genotype=s[0, 0], site=0, emission_matrix=e, ) V[0, j1, j2] = 1 / (n**2) * emission_prob for l in range(1, m): - emission_probs = core.get_emission_probability_diploid_G( - ref_G=G[l, :, :], - query_allele=s[0, l], + emission_probs = core.get_emission_probability_diploid_genotypes( + ref_genotypes=G[l, :, :], + query_genotype=s[0, l], site=l, emission_matrix=e, ) @@ -73,8 +73,8 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r): for j1 in range(n): for j2 in range(n): emission_prob = core.get_emission_probability_diploid( - ref_allele=G[0, j1, j2], - query_allele=s[0, 0], + 
ref_genotype=G[0, j1, j2], + query_genotype=s[0, 0], site=0, emission_matrix=e, ) @@ -84,9 +84,9 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r): # see if we can pinch some ideas. # Diploid Viterbi, with smaller memory footprint. for l in range(1, m): - emission_probs = core.get_emission_probability_diploid_G( - ref_G=G[l, :, :], - query_allele=s[0, l], + emission_probs = core.get_emission_probability_diploid_genotypes( + ref_genotypes=G[l, :, :], + query_genotype=s[0, l], site=l, emission_matrix=e, ) @@ -132,8 +132,8 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r): for j1 in range(n): for j2 in range(n): emission_prob = core.get_emission_probability_diploid( - ref_allele=G[0, j1, j2], - query_allele=s[0, 0], + ref_genotype=G[0, j1, j2], + query_genotype=s[0, 0], site=0, emission_matrix=e, ) @@ -141,9 +141,9 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r): # Diploid Viterbi, with smaller memory footprint, rescaling, and using the structure of the HMM. for l in range(1, m): - emission_probs = core.get_emission_probability_diploid_G( - ref_G=G[l, :, :], - query_allele=s[0, l], + emission_probs = core.get_emission_probability_diploid_genotypes( + ref_genotypes=G[l, :, :], + query_genotype=s[0, l], site=l, emission_matrix=e, ) @@ -221,8 +221,8 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): for j1 in range(n): for j2 in range(n): emission_prob = core.get_emission_probability_diploid( - ref_allele=G[0, j1, j2], - query_allele=s[0, 0], + ref_genotype=G[0, j1, j2], + query_genotype=s[0, 0], site=0, emission_matrix=e, ) @@ -230,9 +230,9 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): # Diploid Viterbi, with smaller memory footprint, rescaling, and using the structure of the HMM. 
for l in range(1, m): - emission_probs = core.get_emission_probability_diploid_G( - ref_G=G[l, :, :], - query_allele=s[0, l], + emission_probs = core.get_emission_probability_diploid_genotypes( + ref_genotypes=G[l, :, :], + query_genotype=s[0, l], site=l, emission_matrix=e, ) @@ -303,8 +303,8 @@ def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r): for j1 in range(n): for j2 in range(n): emission_prob = core.get_emission_probability_diploid( - ref_allele=G[0, j1, j2], - query_allele=s[0, 0], + ref_genotype=G[0, j1, j2], + query_genotype=s[0, 0], site=0, emission_matrix=e, ) @@ -312,9 +312,9 @@ def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r): # Jumped the gun - vectorising. for l in range(1, m): - emission_probs = core.get_emission_probability_diploid_G( - ref_G=G[l, :, :], - query_allele=s[0, l], + emission_probs = core.get_emission_probability_diploid_genotypes( + ref_genotypes=G[l, :, :], + query_genotype=s[0, l], site=l, emission_matrix=e, ) @@ -349,9 +349,9 @@ def forwards_viterbi_dip_naive_full_vec(n, m, G, s, e, r): P = np.zeros((m, n, n), dtype=np.int64) c = np.ones(m) - emission_probs = core.get_emission_probability_diploid_G( - ref_G=G[0, :, :], - query_allele=s[0, 0], - site=l, + emission_probs = core.get_emission_probability_diploid_genotypes( + ref_genotypes=G[0, :, :], + query_genotype=s[0, 0], + site=0, emission_matrix=e, ) @@ -359,9 +359,9 @@ def forwards_viterbi_dip_naive_full_vec(n, m, G, s, e, r): r_n = r / n for l in range(1, m): - emission_probs = core.get_emission_probability_diploid_G( - ref_G=G[l, :, :], - query_allele=s[0, l], + emission_probs = core.get_emission_probability_diploid_genotypes( + ref_genotypes=G[l, :, :], + query_genotype=s[0, l], site=l, emission_matrix=e, ) @@ -460,8 +460,8 @@ def path_ll_dip(n, m, G, phased_path, s, e, r): This is exposed via the API. 
""" emission_prob = core.get_emission_probability_diploid( - ref_allele=G[0, phased_path[0][0], phased_path[1][0]], - query_allele=s[0, 0], + ref_genotype=G[0, phased_path[0][0], phased_path[1][0]], + query_genotype=s[0, 0], site=0, emission_matrix=e, ) @@ -472,8 +472,8 @@ def path_ll_dip(n, m, G, phased_path, s, e, r): for l in range(1, m): emission_prob = core.get_emission_probability_diploid( - ref_allele=G[l, phased_path[0][l], phased_path[1][l]], - query_allele=s[0, l], + ref_genotype=G[l, phased_path[0][l], phased_path[1][l]], + query_genotype=s[0, l], site=l, emission_matrix=e, )