diff --git a/lshmm/core.py b/lshmm/core.py index dd822a1..79c0f08 100644 --- a/lshmm/core.py +++ b/lshmm/core.py @@ -68,7 +68,18 @@ def get_index_in_emission_prob_matrix_diploid(ref_allele, query_allele): if query_allele == MISSING: return MISSING_INDEX else: - is_allele_match = ref_allele == query_allele + is_match = ref_allele == query_allele is_ref_one = ref_allele == 1 is_query_one = query_allele == 1 - return 4 * is_allele_match + 2 * is_ref_one + is_query_one + return 4 * is_match + 2 * is_ref_one + is_query_one + + +@jit.numba_njit +def get_index_in_emission_prob_matrix_diploid_G(ref_G, query_allele, n): + if query_allele == MISSING: + return MISSING_INDEX * np.ones((n, n), dtype=np.int64) + else: + is_match = ref_G == query_allele + is_ref_one = ref_G == 1 + is_query_one = query_allele == 1 + return 4 * is_match + 2 * is_ref_one + is_query_one diff --git a/lshmm/vit_diploid.py b/lshmm/vit_diploid.py index e1c0304..83312fa 100644 --- a/lshmm/vit_diploid.py +++ b/lshmm/vit_diploid.py @@ -26,14 +26,11 @@ def forwards_viterbi_dip_naive(n, m, G, s, e, r): V[0, j1, j2] = 1 / (n**2) * e[0, emission_index] for l in range(1, m): - if s[0, l] == core.MISSING: - index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) - else: - index = ( - 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) - + 2 * (G[l, :, :] == 1).astype(np.int64) - + np.int64(s[0, l] == 1) - ) + emission_index = core.get_index_in_emission_prob_matrix_diploid_G( + ref_G=G[l, :, :], + query_allele=s[0, l], + n=n, + ) for j1 in range(n): for j2 in range(n): @@ -50,7 +47,7 @@ def forwards_viterbi_dip_naive(n, m, G, s, e, r): v[k1, k2] *= r_n[l] * (1 - r[l]) + r_n[l] ** 2 else: v[k1, k2] *= r_n[l] ** 2 - V[l, j1, j2] = np.amax(v) * e[l, index[j1, j2]] + V[l, j1, j2] = np.amax(v) * e[l, emission_index[j1, j2]] P[l, j1, j2] = np.argmax(v) c[l] = np.amax(V[l, :, :]) V[l, :, :] *= 1 / c[l] @@ -81,17 +78,14 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r): # see if we can pinch some ideas. # Diploid Viterbi, with smaller memory footprint. for l in range(1, m): - if s[0, l] == core.MISSING: - index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) - else: - index = ( - 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) - + 2 * (G[l, :, :] == 1).astype(np.int64) - + np.int64(s[0, l] == 1) - ) + emission_index = core.get_index_in_emission_prob_matrix_diploid_G( + ref_G=G[l, :, :], + query_allele=s[0, l], + n=n, + ) + for j1 in range(n): for j2 in range(n): - # Get the vector to maximise over v = np.zeros((n, n)) for k1 in range(n): for k2 in range(n): @@ -104,7 +98,7 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r): v[k1, k2] *= r_n[l] * (1 - r[l]) + r_n[l] ** 2 else: v[k1, k2] *= r_n[l] ** 2 - V[j1, j2] = np.amax(v) * e[l, index[j1, j2]] + V[j1, j2] = np.amax(v) * e[l, emission_index[j1, j2]] P[l, j1, j2] = np.argmax(v) c[l] = np.amax(V) V_prev = np.copy(V) / c[l] @@ -133,14 +127,11 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r): # Diploid Viterbi, with smaller memory footprint, rescaling, and using the structure of the HMM. for l in range(1, m): - if s[0, l] == core.MISSING: - index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) - else: - index = ( - 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) - + 2 * (G[l, :, :] == 1).astype(np.int64) - + np.int64(s[0, l] == 1) - ) + emission_index = core.get_index_in_emission_prob_matrix_diploid_G( + ref_G=G[l, :, :], + query_allele=s[0, l], + n=n, + ) c[l] = np.amax(V_prev) argmax = np.argmax(V_prev) @@ -183,7 +174,7 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r): V[j1, j2] = double_switch P[l, j1, j2] = argmax - V[j1, j2] *= e[l, index[j1, j2]] + V[j1, j2] *= e[l, emission_index[j1, j2]] j1_j2 += 1 V_prev = np.copy(V) @@ -221,14 +212,11 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): # Diploid Viterbi, with smaller memory footprint, rescaling, and using the structure of the HMM. for l in range(1, m): - if s[0, l] == core.MISSING: - index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) - else: - index = ( - 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) - + 2 * (G[l, :, :] == 1).astype(np.int64) - + np.int64(s[0, l] == 1) - ) + emission_index = core.get_index_in_emission_prob_matrix_diploid_G( + ref_G=G[l, :, :], + query_allele=s[0, l], + n=n, + ) c[l] = np.amax(V_prev) argmax = np.argmax(V_prev) @@ -264,7 +252,7 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): V[j1, j2] = double_switch recombs_double[l] = np.append(recombs_double[l], values=j1_j2) - V[j1, j2] *= e[l, index[j1, j2]] + V[j1, j2] *= e[l, emission_index[j1, j2]] j1_j2 += 1 V_prev = np.copy(V) @@ -302,14 +290,11 @@ def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r): # Jumped the gun - vectorising. for l in range(1, m): - if s[0, l] == core.MISSING: - index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) - else: - index = ( - 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) - + 2 * (G[l, :, :] == 1).astype(np.int64) - + np.int64(s[0, l] == 1) - ) + emission_index = core.get_index_in_emission_prob_matrix_diploid_G( + ref_G=G[l, :, :], + query_allele=s[0, l], + n=n, + ) for j1 in range(n): for j2 in range(n): @@ -318,7 +303,7 @@ def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r): v[j1, :] += r_n[l] * (1 - r[l]) v[:, j2] += r_n[l] * (1 - r[l]) v *= V[l - 1, :, :] - V[l, j1, j2] = np.amax(v) * e[l, index[j1, j2]] + V[l, j1, j2] = np.amax(v) * e[l, emission_index[j1, j2]] P[l, j1, j2] = np.argmax(v) c[l] = np.amax(V[l, :, :]) @@ -341,26 +326,20 @@ def forwards_viterbi_dip_naive_full_vec(n, m, G, s, e, r): P = np.zeros((m, n, n), dtype=np.int64) c = np.ones(m) - if s[0, 0] == core.MISSING: - index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) - else: - index = ( - 4 * np.equal(G[0, :, :], s[0, 0]).astype(np.int64) - + 2 * (G[0, :, :] == 1).astype(np.int64) - + np.int64(s[0, 0] == 1) - ) - V[0, :, :] = 1 / (n**2) * e[0, index] + emission_index = core.get_index_in_emission_prob_matrix_diploid_G( + ref_G=G[0, :, :], + query_allele=s[0, 0], + n=n, + ) + V[0, :, :] = 1 / (n**2) * e[0, emission_index] r_n = r / n for l in range(1, m): - if s[0, l] == core.MISSING: - index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) - else: - index = ( - 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) - + 2 * (G[l, :, :] == 1).astype(np.int64) - + np.int64(s[0, l] == 1) - ) + emission_index = core.get_index_in_emission_prob_matrix_diploid_G( + ref_G=G[l, :, :], + query_allele=s[0, l], + n=n, + ) v = ( (r_n[l] ** 2) + (1 - r[l]) ** 2 * char_both @@ -368,7 +347,7 @@ def forwards_viterbi_dip_naive_full_vec(n, m, G, s, e, r): ) v *= V[l - 1, :, :] P[l, :, :] = np.argmax(v.reshape(n, n, -1), 2) # Have to flatten to use argmax - V[l, :, :] = v.reshape(n, n, -1)[rows, cols, P[l, :, :]] * e[l, index] + V[l, :, :] = v.reshape(n, n, -1)[rows, cols, P[l, :, :]] * e[l, emission_index] c[l] = np.amax(V[l, :, :]) V[l, :, :] *= 1 / c[l] @@ -455,17 +434,10 @@ def path_ll_dip(n, m, G, phased_path, s, e, r): r_n = r / n for l in range(1, m): - if s[0, l] == core.MISSING: - index = core.MISSING_INDEX - else: - index = ( - 4 - * np.int64( - np.equal(G[l, phased_path[0][l], phased_path[1][l]], s[0, l]) - ) - + 2 * np.int64(G[l, phased_path[0][l], phased_path[1][l]] == 1) - + np.int64(s[0, l] == 1) - ) + emission_index = core.get_index_in_emission_prob_matrix_diploid( + ref_allele=G[l, phased_path[0][l], phased_path[1][l]], + query_allele=s[0, l], + ) current_phase = np.array([phased_path[0][l], phased_path[1][l]]) phase_diff = np.sum(~np.equal(current_phase, old_phase)) @@ -479,7 +451,7 @@ def path_ll_dip(n, m, G, phased_path, s, e, r): else: log_prob_path += np.log10(r_n[l] ** 2) - log_prob_path += np.log10(e[l, index]) + log_prob_path += np.log10(e[l, emission_index]) old_phase = current_phase return log_prob_path diff --git a/tests/test_API_multiallelic.py b/tests/test_API_multiallelic.py index 68a7707..5c01c3d 100644 --- a/tests/test_API_multiallelic.py +++ b/tests/test_API_multiallelic.py @@ -145,11 +145,7 @@ def test_simple_n_8(self): population_size=10000, random_seed=42, ) - ts = msprime.sim_mutations( - ts, - rate=1e-4, - random_seed=42 - ) + ts = msprime.sim_mutations(ts, rate=1e-4, random_seed=42) assert ts.num_sites > 5 assert ts.num_trees > 15 self.verify(ts) diff --git a/tests/test_LS_haploid_diploid.py b/tests/test_LS_haploid_diploid.py index a83943a..086d79d 100644 --- a/tests/test_LS_haploid_diploid.py +++ b/tests/test_LS_haploid_diploid.py @@ -36,14 +36,12 @@ def example_haplotypes(self, ts): return H, haplotypes def haplotype_emission(self, mu, m): - # Define the emission probability matrix e = np.zeros((m, 2)) - e[:, 0] = mu # If they match - e[:, 1] = 1 - mu # If they don't match + e[:, 0] = mu + e[:, 1] = 1 - mu return e def genotype_emission(self, mu, m): - # Define the emission probability matrix e = np.zeros((m, 8)) e[:, core.EQUAL_BOTH_HOM] = (1 - mu) ** 2 e[:, core.UNEQUAL_BOTH_HOM] = mu**2 @@ -92,10 +90,8 @@ def example_parameters_haplotypes_larger( r = mean_r * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) r[0] = 0 - # Error probability mu = mean_mu * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) - # Define the emission probability matrix e = self.haplotype_emission(mu, m) for s in haplotypes: @@ -167,46 +163,65 @@ def example_parameters_genotypes_larger( r = mean_r * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) r[0] = 0 - # Error probability mu = mean_mu * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) - # Define the emission probability matrix e = self.genotype_emission(mu, m) for s in genotypes: yield n, m, G, s, e, r def assertAllClose(self, A, B): - """Assert that all entries of two matrices are 'close'""" - # assert np.allclose(A, B, rtol=1e-9, atol=0.0) assert np.allclose(A, B, rtol=1e-09, atol=1e-08) # Define a bunch of very small tree-sequences for testing a collection of parameters on def test_simple_n_10_no_recombination(self): ts = msprime.simulate( - 10, recombination_rate=0, mutation_rate=0.5, random_seed=42 + 10, + recombination_rate=0, + mutation_rate=0.5, + random_seed=42, ) assert ts.num_sites > 3 self.verify(ts) def test_simple_n_6(self): - ts = msprime.simulate(6, recombination_rate=2, mutation_rate=7, random_seed=42) + ts = msprime.simulate( + 6, + recombination_rate=2, + mutation_rate=7, + random_seed=42, + ) assert ts.num_sites > 5 self.verify(ts) def test_simple_n_8(self): - ts = msprime.simulate(8, recombination_rate=2, mutation_rate=5, random_seed=42) + ts = msprime.simulate( + 8, + recombination_rate=2, + mutation_rate=5, + random_seed=42, + ) assert ts.num_sites > 5 self.verify(ts) def test_simple_n_8_high_recombination(self): - ts = msprime.simulate(8, recombination_rate=20, mutation_rate=5, random_seed=42) + ts = msprime.simulate( + 8, + recombination_rate=20, + mutation_rate=5, + random_seed=42, + ) assert ts.num_trees > 15 assert ts.num_sites > 5 self.verify(ts) def test_simple_n_16(self): - ts = msprime.simulate(16, recombination_rate=2, mutation_rate=5, random_seed=42) + ts = msprime.simulate( + 16, + recombination_rate=2, + mutation_rate=5, + random_seed=42, + ) assert ts.num_sites > 5 self.verify(ts) @@ -233,7 +248,7 @@ class FBAlgorithmBase(LSBase): class TestNonTreeMethodsHap(FBAlgorithmBase): - """Test that we compute the sample likelihoods across all implementations.""" + """Test that the computed likelihoods are the same across all implementations.""" def verify(self, ts): for n, m, H_vs, s, e_vs, r in self.example_parameters_haplotypes(ts): @@ -269,7 +284,7 @@ def verify_larger(self, ts): class TestNonTreeMethodsDip(FBAlgorithmBase): - """Test that we compute the sample likelihoods across all implementations.""" + """Test that the computed likelihoods are the same across all implementations.""" def verify(self, ts): for n, m, G_vs, s, e_vs, r in self.example_parameters_genotypes(ts): @@ -356,7 +371,7 @@ class VitAlgorithmBase(LSBase): class TestNonTreeViterbiHap(VitAlgorithmBase): - """Test that we have the same log-likelihood across all implementations""" + """Test that the computed log-likelihoods are the same across all implementations.""" def verify(self, ts): for n, m, H_vs, s, e_vs, r in self.example_parameters_haplotypes(ts): @@ -476,7 +491,7 @@ def verify_larger(self, ts): class TestNonTreeViterbiDip(VitAlgorithmBase): - """Test that we have the same log-likelihood across all implementations""" + """Test that the computed log-likelihoods are the same across all implementations.""" def verify(self, ts): for n, m, G_vs, s, e_vs, r in self.example_parameters_genotypes(ts):