diff --git a/tests/lsbase.py b/tests/lsbase.py index 1eae6e8..e9acef1 100644 --- a/tests/lsbase.py +++ b/tests/lsbase.py @@ -67,20 +67,27 @@ def get_examples_haploid(self, ts, include_ancestors): ref_panel = self.get_ancestral_haplotypes(ts) else: ref_panel = ts.genotype_matrix() - num_sites = ref_panel.shape[0] # Take some haplotypes as queries from the reference panel. - query1 = ref_panel[:, 0].reshape(1, num_sites) - query2 = ref_panel[:, 1].reshape(1, num_sites) + num_sites = ref_panel.shape[0] + query_1 = ref_panel[:, 0].reshape(1, num_sites) + query_1 = np.append(query_1[:2], query_1[2:]).reshape(1, num_sites) + query_2 = query_1[::-1] # Create queries with MISSING. - query_miss_last = query1.copy() + query_miss_last = query_1.copy() query_miss_last[0, -1] = core.MISSING - query_miss_mid = query1.copy() + query_miss_mid = query_1.copy() query_miss_mid[0, ts.num_sites // 2] = core.MISSING - query_miss_most = query1.copy() - query_miss_most[0, 2:] = core.MISSING - queries = [query1, query2, query_miss_last, query_miss_mid, query_miss_most] - # Exclude the arbitrarily chosen queries from the reference panel. - ref_panel = ref_panel[:, 2:] + query_miss_most_1 = query_1.copy() + query_miss_most_1[0, 2:] = core.MISSING + query_miss_most_2 = query_miss_most_1[::-1] + queries = [ + query_1, + query_2, + query_miss_last, + query_miss_mid, + query_miss_most_1, + query_miss_most_2, + ] return ref_panel, queries def get_examples_diploid(self, ts, include_ancestors): @@ -88,24 +95,29 @@ def get_examples_diploid(self, ts, include_ancestors): ref_panel = self.get_ancestral_haplotypes(ts) else: ref_panel = ts.genotype_matrix() - num_sites = ref_panel.shape[0] # Take some haplotypes as queries from the reference panel. - query_1 = np.zeros((2, num_sites), dtype=np.int32) - np.inf + num_sites = ref_panel.shape[0] + query_1 = np.zeros((2, num_sites), dtype=np.int8) - np.inf query_1[0, :] = ref_panel[:, 0].reshape(1, num_sites) query_1[1, :] = ref_panel[:, 1].reshape(1, num_sites) - query_2 = np.zeros((2, num_sites), dtype=np.int32) - np.inf - query_2[0, :] = ref_panel[:, 2].reshape(1, num_sites) - query_2[1, :] = ref_panel[:, 3].reshape(1, num_sites) + query_1 = np.append(query_1[:, :2], query_1[:, 2:]).reshape(2, num_sites) + query_2 = query_1[:, ::-1] # Create queries with MISSING. query_miss_last = query_1.copy() query_miss_last[:, -1] = core.MISSING query_miss_mid = query_1.copy() query_miss_mid[:, ts.num_sites // 2] = core.MISSING - query_miss_most = query_1.copy() - query_miss_most[:, 1:] = core.MISSING - queries = [query_1, query_2, query_miss_last, query_miss_mid, query_miss_most] - # Exclude the arbitrarily chosen queries from the reference panel. - # ref_panel = ref_panel[:, 4:] + query_miss_most_1 = query_1.copy() + query_miss_most_1[:, 2:] = core.MISSING + query_miss_most_2 = query_miss_most_1[:, ::-1] + queries = [ + query_1, + query_2, + query_miss_last, + query_miss_mid, + query_miss_most_1, + query_miss_most_2, + ] return ref_panel, queries def get_examples_pars(