diff --git a/tests/lsbase.py b/tests/lsbase.py index 1eae6e8..3e26387 100644 --- a/tests/lsbase.py +++ b/tests/lsbase.py @@ -88,24 +88,22 @@ def get_examples_diploid(self, ts, include_ancestors): ref_panel = self.get_ancestral_haplotypes(ts) else: ref_panel = ts.genotype_matrix() - num_sites = ref_panel.shape[0] # Take some haplotypes as queries from the reference panel. - query_1 = np.zeros((2, num_sites), dtype=np.int32) - np.inf + num_sites = ref_panel.shape[0] + query_1 = np.zeros((2, num_sites), dtype=np.int8) - np.inf query_1[0, :] = ref_panel[:, 0].reshape(1, num_sites) query_1[1, :] = ref_panel[:, 1].reshape(1, num_sites) - query_2 = np.zeros((2, num_sites), dtype=np.int32) - np.inf - query_2[0, :] = ref_panel[:, 2].reshape(1, num_sites) - query_2[1, :] = ref_panel[:, 3].reshape(1, num_sites) + query_1 = np.append(query_1[:, :2], query_1[:, 2:]).reshape(2, num_sites) + query_2 = query_1[:, ::-1] # Create queries with MISSING. query_miss_last = query_1.copy() query_miss_last[:, -1] = core.MISSING query_miss_mid = query_1.copy() query_miss_mid[:, ts.num_sites // 2] = core.MISSING - query_miss_most = query_1.copy() - query_miss_most[:, 1:] = core.MISSING - queries = [query_1, query_2, query_miss_last, query_miss_mid, query_miss_most] - # Exclude the arbitrarily chosen queries from the reference panel. - # ref_panel = ref_panel[:, 4:] + query_miss_most_1 = query_1.copy() + query_miss_most_1[:, 2:] = core.MISSING + query_miss_most_2 = query_miss_most_1[:, ::-1] + queries = [query_1, query_2, query_miss_last, query_miss_mid, query_miss_most_1, query_miss_most_2] return ref_panel, queries def get_examples_pars(