From 54556c5977c458a224f6228196a3e79347543f54 Mon Sep 17 00:00:00 2001 From: szhan Date: Mon, 24 Jun 2024 13:45:45 +0100 Subject: [PATCH] Update how MISSING is handled when updating probabilities --- python/tests/test_haplotype_matching.py | 43 +++++++++++++++---------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index 6b2ad23b0f..dd49f899ae 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -333,11 +333,10 @@ def update_probabilities(self, site, haplotype_state): while allelic_state[v] == -1: v = tree.parent(v) assert v != -1 - match = ( - haplotype_state == MISSING or haplotype_state == allelic_state[v] - ) + match = haplotype_state == allelic_state[v] + is_query_missing = haplotype_state == MISSING # Note that the node u is used only by Viterbi - st.value = self.compute_next_probability(site.id, st.value, match, u) + st.value = self.compute_next_probability(site.id, st.value, match, is_query_missing, u) # Unset the states allelic_state[tree.root] = -1 @@ -404,7 +403,7 @@ def run(self, h): def compute_normalisation_factor(self): raise NotImplementedError() - def compute_next_probability(self, site_id, p_last, is_match, node): + def compute_next_probability(self, site_id, p_last, is_match, is_query_missing, node): raise NotImplementedError() @@ -435,10 +434,13 @@ def compute_normalisation_factor(self): s += self.N[j] * st.value return s - def compute_next_probability(self, site_id, p_last, is_match, node): + def compute_next_probability(self, site_id, p_last, is_match, is_query_missing, node): rho = self.rho[site_id] n = self.ts.num_samples - p_e = self.compute_emission_proba(site_id, is_match) + if is_query_missing: + p_e = 1.0 + else: + p_e = self.compute_emission_proba(site_id, is_match) p_t = p_last * (1 - rho) + rho / n return p_t * p_e @@ -448,8 +450,11 @@ class BackwardAlgorithm(ForwardAlgorithm): The Li and Stephens backward algorithm. """ - def compute_next_probability(self, site_id, p_next, is_match, node): - p_e = self.compute_emission_proba(site_id, is_match) + def compute_next_probability(self, site_id, p_next, is_match, is_query_missing, node): + if is_query_missing: + p_e = 1.0 + else: + p_e = self.compute_emission_proba(site_id, is_match) return p_next * p_e def process_site(self, site, haplotype_state, s): @@ -515,7 +520,7 @@ def compute_normalisation_factor(self): ) return max_st.value - def compute_next_probability(self, site_id, p_last, is_match, node): + def compute_next_probability(self, site_id, p_last, is_match, is_query_missing, node): rho = self.rho[site_id] n = self.ts.num_samples @@ -529,7 +534,11 @@ def compute_next_probability(self, site_id, p_last, is_match, node): recombination_required = True self.output.add_recombination_required(site_id, node, recombination_required) - p_e = self.compute_emission_proba(site_id, is_match) + if is_query_missing: + p_e = 1.0 + else: + p_e = self.compute_emission_proba(site_id, is_match) + return p_t * p_e @@ -679,12 +688,12 @@ def traceback(self): def get_site_alleles(ts, h, alleles): if alleles is None: - n_alleles = np.int8( - [ - len(np.unique(np.append(ts.genotype_matrix()[j, :], h[j]))) - for j in range(ts.num_sites) - ] - ) + n_alleles = np.zeros(ts.num_sites, dtype=np.int8) - 1 + for j in range(ts.num_sites): + uniq_alleles = np.unique(np.append(ts.genotype_matrix()[j, :], h[j])) + uniq_alleles = uniq_alleles[uniq_alleles != MISSING] + n_alleles[j] = len(uniq_alleles) + assert np.all(n_alleles > 0) alleles = tskit.ALLELES_ACGT if len(set(alleles).intersection(next(ts.variants()).alleles)) == 0: alleles = tskit.ALLELES_01