diff --git a/tests/lsbase.py b/tests/lsbase.py index 3e26387..e9acef1 100644 --- a/tests/lsbase.py +++ b/tests/lsbase.py @@ -67,20 +67,27 @@ def get_examples_haploid(self, ts, include_ancestors): ref_panel = self.get_ancestral_haplotypes(ts) else: ref_panel = ts.genotype_matrix() - num_sites = ref_panel.shape[0] # Take some haplotypes as queries from the reference panel. - query1 = ref_panel[:, 0].reshape(1, num_sites) - query2 = ref_panel[:, 1].reshape(1, num_sites) + num_sites = ref_panel.shape[0] + query_1 = ref_panel[:, 0].reshape(1, num_sites) + query_1 = np.append(query_1[:2], query_1[2:]).reshape(1, num_sites) + query_2 = query_1[::-1] # Create queries with MISSING. - query_miss_last = query1.copy() + query_miss_last = query_1.copy() query_miss_last[0, -1] = core.MISSING - query_miss_mid = query1.copy() + query_miss_mid = query_1.copy() query_miss_mid[0, ts.num_sites // 2] = core.MISSING - query_miss_most = query1.copy() - query_miss_most[0, 2:] = core.MISSING - queries = [query1, query2, query_miss_last, query_miss_mid, query_miss_most] - # Exclude the arbitrarily chosen queries from the reference panel. - ref_panel = ref_panel[:, 2:] + query_miss_most_1 = query_1.copy() + query_miss_most_1[0, 2:] = core.MISSING + query_miss_most_2 = query_miss_most_1[::-1] + queries = [ + query_1, + query_2, + query_miss_last, + query_miss_mid, + query_miss_most_1, + query_miss_most_2, + ] return ref_panel, queries def get_examples_diploid(self, ts, include_ancestors): @@ -103,7 +110,14 @@ def get_examples_diploid(self, ts, include_ancestors): query_miss_most_1 = query_1.copy() query_miss_most_1[:, 2:] = core.MISSING query_miss_most_2 = query_miss_most_1[:, ::-1] - queries = [query_1, query_2, query_miss_last, query_miss_mid, query_miss_most_1, query_miss_most_2] + queries = [ + query_1, + query_2, + query_miss_last, + query_miss_mid, + query_miss_most_1, + query_miss_most_2, + ] return ref_panel, queries def get_examples_pars(