From 1601a5e09b68191d062e52c1a0dcec093149e1c1 Mon Sep 17 00:00:00 2001
From: lkirk
Date: Mon, 16 Oct 2023 01:35:29 -0500
Subject: [PATCH] Implement a method for specifying row/col sites

This allows a user to specify an array of row sites to be compared against
an array of column sites. Rectangular matrices are returned as a result of
these comparisons. We weighed the pros/cons of indexing performance versus
redundant computation. The initial implementation avoided redundant
computation, but at the expense of complexity and readability. The final
implementation is much simpler and DOES perform redundant computations if
the row/col sites cross the diagonal of the LD matrix. We will revisit
if/when this becomes a problem.

In addition to these changes, I've cleaned up the results matrix allocation.
Since the python code will be allocating and managing this memory, we'll
pull out the allocation code.

C tests have been added to capture all of this new functionality. A
type-annotated python prototype that mirrors the C functionality was added
to provide some documentation on how the C algorithms work, and to aid in
the testing of the C code once the python API is ready.
---
 c/tests/test_stats.c           | 387 ++++++++++-----
 c/tskit/trees.c                | 335 +++++++------
 c/tskit/trees.h                |  54 +--
 python/tests/test_ld_matrix.py | 833 +++++++++++++++++++++++++++++++++
 4 files changed, 1319 insertions(+), 290 deletions(-)
 create mode 100644 python/tests/test_ld_matrix.py

diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c
index 14decb18ff..214d473d41 100644
--- a/c/tests/test_stats.c
+++ b/c/tests/test_stats.c
@@ -2036,22 +2036,28 @@ static void
 test_paper_ex_two_site(void)
 {
     tsk_treeseq_t ts;
-    double *result;
-    tsk_size_t s, result_size;
+    double result[27];
+    tsk_size_t s, result_size, num_sample_sets;
     int ret;
-    double truth_one_set[6] = { 1, 0.1111111111111111, 0.1111111111111111, 1, 1, 1 };
-    double truth_two_sets[12] = { 1, 1, 0.1111111111111111, 0.1111111111111111,
-        0.1111111111111111, 0.1111111111111111, 1, 1, 1, 1, 1, 1 };
-    double truth_three_sets[18] = { 1, 1, 0, 0.1111111111111111, 0.1111111111111111, 0,
-        0.1111111111111111, 0.1111111111111111, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
+    double truth_one_set[9] = { 1, 0.1111111111111111, 0.1111111111111111,
+        0.1111111111111111, 1, 1, 0.1111111111111111, 1, 1 };
+    double truth_two_sets[18] = { 1, 1, 0.1111111111111111, 0.1111111111111111,
+        0.1111111111111111, 0.1111111111111111, 0.1111111111111111, 0.1111111111111111,
+        1, 1, 1, 1, 0.1111111111111111, 0.1111111111111111, 1, 1, 1, 1 };
+    double truth_three_sets[27]
+        = { 1, 1, 0, 0.1111111111111111, 0.1111111111111111, 0, 0.1111111111111111,
+              0.1111111111111111, 0, 0.1111111111111111, 0.1111111111111111, 0, 1, 1, 1,
+              1, 1, 1, 0.1111111111111111, 0.1111111111111111, 0, 1, 1, 1, 1, 1, 1 };
     tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites,
         paper_ex_mutations, paper_ex_individuals, NULL, 0);
     tsk_size_t sample_set_sizes[3];
-    tsk_size_t num_sample_sets;
     tsk_id_t sample_sets[ts.num_samples * 3];
+    tsk_size_t num_sites = ts.tables->sites.num_rows;
+    tsk_id_t *row_sites = tsk_malloc(num_sites * sizeof(*row_sites));
+    tsk_id_t *col_sites = tsk_malloc(num_sites * sizeof(*col_sites));
     // First sample set contains all of the samples
     sample_set_sizes[0] = ts.num_samples;
@@ -2059,14 +2065,18 @@ test_paper_ex_two_site(void)
     for (s = 0; s < ts.num_samples; s++) {
         sample_sets[s] = (tsk_id_t) s;
     }
+    for (s = 0; s < num_sites; s++) {
+        row_sites[s] = (tsk_id_t) s;
+        col_sites[s] =
(tsk_id_t) s; + } - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + result_size = num_sites * num_sites; + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 6); assert_arrays_almost_equal(result_size * num_sample_sets, result, truth_one_set); - tsk_safe_free(result); // Second sample set contains all of the samples sample_set_sizes[1] = ts.num_samples; @@ -2075,13 +2085,12 @@ test_paper_ex_two_site(void) sample_sets[s] = (tsk_id_t) s - (tsk_id_t) ts.num_samples; } - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 6); assert_arrays_almost_equal(result_size * num_sample_sets, result, truth_two_sets); - tsk_safe_free(result); // Third sample set contains the first two samples sample_set_sizes[2] = 2; @@ -2090,15 +2099,16 @@ test_paper_ex_two_site(void) sample_sets[s] = (tsk_id_t) s - (tsk_id_t) ts.num_samples * 2; } - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 6); assert_arrays_almost_equal(result_size * num_sample_sets, result, truth_three_sets); - tsk_safe_free(result); tsk_treeseq_free(&ts); + tsk_safe_free(row_sites); + tsk_safe_free(col_sites); } static void @@ -2145,83 +2155,86 @@ test_two_site_correlated_multiallelic(void) int ret; tsk_treeseq_t ts; - double *result; tsk_size_t s, result_size; - double truth_D[3] - = { 0.043209876543209874, -0.018518518518518517, 0.05555555555555555 }; - double truth_D2[3] - = { 0.023844603634269844, 0.02384460363426984, 0.02384460363426984 }; - double truth_r2[3] = { 1, 1, 1 }; - double truth_D_prime[3] - = { 0.7777777777777777, 0.4444444444444444, 0.6666666666666666 }; - double truth_r[3] - = { 0.18377223398316206, -0.12212786219416509, 0.2609542781331212 }; - double truth_Dz[3] - = { 0.0033870175616860566, 0.003387017561686057, 0.003387017561686057 }; - double truth_pi2[3] - = { 0.04579247743399549, 0.04579247743399549, 0.0457924774339955 }; + double truth_D[4] = { 0.043209876543209874, -0.018518518518518517, + -0.018518518518518517, 0.05555555555555555 }; + double truth_D2[4] = { 0.023844603634269844, 0.02384460363426984, + 0.02384460363426984, 0.02384460363426984 }; + double truth_r2[4] = { 1, 1, 1, 1 }; + double truth_D_prime[4] = { 0.7777777777777777, 0.4444444444444444, + 0.4444444444444444, 0.6666666666666666 }; + double truth_r[4] = { 0.18377223398316206, -0.12212786219416509, + -0.12212786219416509, 0.2609542781331212 }; + double truth_Dz[4] = { 0.0033870175616860566, 0.003387017561686057, + 0.003387017561686057, 0.003387017561686057 }; + double truth_pi2[4] = { 0.04579247743399549, 0.04579247743399549, + 0.04579247743399549, 0.0457924774339955 }; tsk_treeseq_from_text(&ts, 
20, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); - tsk_size_t sample_set_sizes[1] = { ts.num_samples }; tsk_size_t num_sample_sets = 1; + tsk_size_t sample_set_sizes[1] = { ts.num_samples }; tsk_id_t sample_sets[ts.num_samples]; + tsk_size_t num_sites = ts.tables->sites.num_rows; + tsk_id_t *row_sites = tsk_malloc(num_sites * sizeof(*row_sites)); + tsk_id_t *col_sites = tsk_malloc(num_sites * sizeof(*col_sites)); + result_size = num_sites * num_sites; + double result[result_size]; for (s = 0; s < ts.num_samples; s++) { sample_sets[s] = (tsk_id_t) s; } + for (s = 0; s < num_sites; s++) { + row_sites[s] = (tsk_id_t) s; + col_sites[s] = (tsk_id_t) s; + } - ret = tsk_treeseq_D(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_D(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_D); - tsk_safe_free(result); - ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_D2); - tsk_safe_free(result); - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_r2); - tsk_safe_free(result); - ret = tsk_treeseq_D_prime(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, - NULL, 0, NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_D_prime(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_sites, row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_D_prime); - tsk_safe_free(result); - ret = tsk_treeseq_r(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_r(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_r); - tsk_safe_free(result); - ret = tsk_treeseq_Dz(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_Dz(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_Dz); - 
tsk_safe_free(result); - ret = tsk_treeseq_pi2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, - 0, NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_pi2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_pi2); - tsk_safe_free(result); tsk_treeseq_free(&ts); + tsk_safe_free(row_sites); + tsk_safe_free(col_sites); } static void @@ -2278,78 +2291,81 @@ test_two_site_uncorrelated_multiallelic(void) tsk_treeseq_t ts; int ret; - double *result; - tsk_size_t result_size; - - double truth_D[3] = { 0.05555555555555555, 0.0, 0.05555555555555555 }; - double truth_D2[3] = { 0.024691358024691357, 0.0, 0.024691358024691357 }; - double truth_r2[3] = { 1, 0, 1 }; - double truth_D_prime[3] = { 0.6666666666666665, 0.0, 0.6666666666666665 }; - double truth_r[3] = { 0.24999999999999997, 0.0, 0.24999999999999997 }; - double truth_Dz[3] = { 0.0, 0.0, 0.0 }; - double truth_pi2[3] - = { 0.04938271604938272, 0.04938271604938272, 0.04938271604938272 }; + + double truth_D[4] = { 0.05555555555555555, 0.0, 0.0, 0.05555555555555555 }; + double truth_D2[4] = { 0.024691358024691357, 0.0, 0.0, 0.024691358024691357 }; + double truth_r2[4] = { 1, 0, 0, 1 }; + double truth_D_prime[4] = { 0.6666666666666665, 0.0, 0.0, 0.6666666666666665 }; + double truth_r[4] = { 0.24999999999999997, 0.0, 0.0, 0.24999999999999997 }; + double truth_Dz[4] = { 0.0, 0.0, 0.0, 0.0 }; + double truth_pi2[4] = { 0.04938271604938272, 0.04938271604938272, + 0.04938271604938272, 0.04938271604938272 }; tsk_treeseq_from_text(&ts, 20, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); - tsk_size_t sample_set_sizes[1] = { ts.num_samples }; + tsk_size_t s; tsk_size_t num_sample_sets = 1; + tsk_size_t num_sites = ts.tables->sites.num_rows; + tsk_id_t *row_sites = tsk_malloc(num_sites * sizeof(*row_sites)); + tsk_id_t *col_sites = tsk_malloc(num_sites * sizeof(*col_sites)); + tsk_size_t sample_set_sizes[1] = { ts.num_samples }; tsk_id_t sample_sets[ts.num_samples]; + tsk_size_t result_size = num_sites * num_sites; + double result[result_size]; - for (tsk_size_t s = 0; s < ts.num_samples; s++) { + for (s = 0; s < ts.num_samples; s++) { sample_sets[s] = (tsk_id_t) s; } + for (s = 0; s < num_sites; s++) { + row_sites[s] = (tsk_id_t) s; + col_sites[s] = (tsk_id_t) s; + } - ret = tsk_treeseq_D(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_D(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_D); - tsk_safe_free(result); - ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_D2); - tsk_safe_free(result); - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 
0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_r2); - tsk_safe_free(result); - ret = tsk_treeseq_D_prime(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, - NULL, 0, NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_D_prime(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_sites, row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_D_prime); - tsk_safe_free(result); - ret = tsk_treeseq_r(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_r(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_r); - tsk_safe_free(result); - ret = tsk_treeseq_Dz(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_Dz(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_Dz); - tsk_safe_free(result); - ret = tsk_treeseq_pi2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, - 0, NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_pi2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); assert_arrays_almost_equal(result_size, result, truth_pi2); - tsk_safe_free(result); tsk_treeseq_free(&ts); + tsk_safe_free(row_sites); + tsk_safe_free(col_sites); } static void @@ -2386,44 +2402,122 @@ test_two_site_backmutation(void) "1 58 A 4\n"; int ret; - double *result; - tsk_size_t result_size; tsk_treeseq_t ts; tsk_treeseq_from_text(&ts, 10, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); - tsk_size_t sample_set_sizes[1] = { ts.num_samples }; tsk_size_t num_sample_sets = 1; + tsk_size_t num_sites = ts.tables->sites.num_rows; + tsk_id_t *row_sites = tsk_malloc(num_sites * sizeof(*row_sites)); + tsk_id_t *col_sites = tsk_malloc(num_sites * sizeof(*col_sites)); + tsk_size_t sample_set_sizes[1] = { ts.num_samples }; tsk_id_t sample_sets[ts.num_samples]; + tsk_size_t result_size = num_sites * num_sites; + double result[result_size]; + tsk_size_t s; - for (tsk_size_t s = 0; s < ts.num_samples; s++) { + double truth_r2[4] = { 0.999999999999999, 0.042923862278701, 0.042923862278701, 1. 
}; + + for (s = 0; s < ts.num_samples; s++) { sample_sets[s] = (tsk_id_t) s; } + for (s = 0; s < num_sites; s++) { + row_sites[s] = (tsk_id_t) s; + col_sites[s] = (tsk_id_t) s; + } - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(result_size, 3); - /* assert_arrays_almost_equal(result_size, result, truth_r2); */ - tsk_safe_free(result); + assert_arrays_almost_equal(result_size, result, truth_r2); tsk_treeseq_free(&ts); + tsk_safe_free(row_sites); + tsk_safe_free(col_sites); } static void -test_two_locus_stat_input_errors(void) +test_paper_ex_two_site_subset(void) { tsk_treeseq_t ts; - double *result; + double result[4]; + int ret; tsk_size_t s, result_size; + tsk_size_t sample_set_sizes[1]; + tsk_size_t num_sample_sets; + tsk_id_t row_sites[2] = { 0, 1 }; + tsk_id_t col_sites[2] = { 1, 2 }; + double result_truth_1[4] = { 0.1111111111111111, 0.1111111111111111, 1, 1 }; + double result_truth_2[1] = { 0.1111111111111111 }; + double result_truth_3[4] = { 0.1111111111111111, 1, 0.1111111111111111, 1 }; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + + tsk_id_t sample_sets[ts.num_samples]; + + sample_set_sizes[0] = ts.num_samples; + num_sample_sets = 1; + for (s = 0; s < ts.num_samples; s++) { + sample_sets[s] = (tsk_id_t) s; + } + + result_size = 2 * 2; + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, + row_sites, 2, col_sites, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size * num_sample_sets, result, result_truth_1); + + result_size = 1 * 1; + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + col_sites[0] = 2; + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 1, + row_sites, 1, col_sites, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size * num_sample_sets, result, result_truth_2); + + result_size = 2 * 2; + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + row_sites[0] = 1; + row_sites[1] = 2; + col_sites[0] = 0; + col_sites[1] = 1; + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, + row_sites, 2, col_sites, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size * num_sample_sets, result, result_truth_3); + + tsk_treeseq_free(&ts); +} + +static void +test_two_locus_stat_input_errors(void) +{ + tsk_treeseq_t ts; int ret; tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, single_tree_ex_sites, single_tree_ex_mutations, NULL, NULL, 0); - tsk_size_t sample_set_sizes[1]; - tsk_size_t num_sample_sets; + tsk_size_t num_sites = ts.tables->sites.num_rows; + tsk_id_t *row_sites = tsk_malloc(num_sites * sizeof(*row_sites)); + tsk_id_t *col_sites = tsk_malloc(num_sites * sizeof(*col_sites)); + tsk_size_t sample_set_sizes[1] = { ts.num_samples }; + tsk_size_t num_sample_sets = 1; tsk_id_t sample_sets[ts.num_samples]; + tsk_size_t result_size = num_sites * num_sites; + double result[result_size]; + tsk_size_t s; + + for (s = 0; s < ts.num_samples; s++) 
{ + sample_sets[s] = (tsk_id_t) s; + } + for (s = 0; s < num_sites; s++) { + row_sites[s] = (tsk_id_t) s; + col_sites[s] = (tsk_id_t) s; + } sample_set_sizes[0] = ts.num_samples; num_sample_sets = 1; @@ -2432,36 +2526,70 @@ test_two_locus_stat_input_errors(void) } sample_sets[1] = 0; - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); sample_sets[1] = 1; - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, TSK_STAT_SITE | TSK_STAT_BRANCH, &result_size, &result); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, TSK_STAT_SITE | TSK_STAT_BRANCH, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MULTIPLE_STAT_MODES); - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, TSK_STAT_BRANCH, &result_size, &result); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, TSK_STAT_BRANCH, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); - ret = tsk_treeseq_r2(&ts, 0, sample_set_sizes, sample_sets, 0, NULL, 0, NULL, 0, - &result_size, &result); + ret = tsk_treeseq_r2(&ts, 0, sample_set_sizes, sample_sets, num_sites, row_sites, + num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_STATE_DIMS); sample_set_sizes[0] = 0; - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_EMPTY_SAMPLE_SET); sample_set_sizes[0] = ts.num_samples; sample_sets[1] = 10; - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, &result_size, &result); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); sample_sets[1] = 1; + row_sites[0] = 1000; + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SITE_OUT_OF_BOUNDS); + row_sites[0] = 0; + + col_sites[num_sites - 1] = (tsk_id_t) num_sites; + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SITE_OUT_OF_BOUNDS); + col_sites[num_sites - 1] = (tsk_id_t) num_sites - 1; + + row_sites[0] = 1; + row_sites[1] = 0; + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSORTED_SITES); + row_sites[0] = 0; + row_sites[1] = 1; + + row_sites[0] = 1; + row_sites[1] = 1; + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, num_sites, col_sites, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSORTED_SITES); + row_sites[0] = 0; + row_sites[1] = 1; + + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, + NULL, 0, result); + 
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SITE_POSITION); + tsk_treeseq_free(&ts); + tsk_safe_free(row_sites); + tsk_safe_free(col_sites); } static void @@ -2744,6 +2872,7 @@ main(int argc, char **argv) { "test_two_site_uncorrelated_multiallelic", test_two_site_uncorrelated_multiallelic }, { "test_two_site_backmutation", test_two_site_backmutation }, + { "test_paper_ex_two_site_subset", test_paper_ex_two_site_subset }, { "test_two_locus_stat_input_errors", test_two_locus_stat_input_errors }, { "test_simplest_divergence_matrix", test_simplest_divergence_matrix }, diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 56b2661c12..61fbf686d2 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2355,18 +2355,63 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state, return ret; } +static void +get_site_row_col_indices(tsk_size_t n_rows, const tsk_id_t *row_sites, tsk_size_t n_cols, + const tsk_id_t *col_sites, tsk_id_t *sites, tsk_size_t *n_sites, tsk_size_t *row_idx, + tsk_size_t *col_idx) +{ + tsk_size_t r = 0, c = 0, s = 0; + + // Iterate rows and columns until we've exhaused one of the lists + while ((r < n_rows) && (c < n_cols)) { + if (row_sites[r] < col_sites[c]) { + sites[s] = row_sites[r]; + row_idx[r] = s; + s++; + r++; + } else if (col_sites[c] < row_sites[r]) { + sites[s] = col_sites[c]; + col_idx[c] = s; + s++; + c++; + } else { // row == col + sites[s] = row_sites[r]; + col_idx[c] = s; + row_idx[r] = s; + s++; + r++; + c++; + } + } + + // If there are any items remaining in the other list, drain it + while (r < n_rows) { + sites[s] = row_sites[r]; + row_idx[r] = s; + s++; + r++; + } + while (c < n_cols) { + sites[s] = col_sites[c]; + col_idx[c] = s; + s++; + c++; + } + *n_sites = s; +} + static int -get_mutation_samples( - const tsk_treeseq_t *ts, tsk_size_t *num_alleles, tsk_bit_array_t *allele_samples) +get_mutation_samples(const tsk_treeseq_t *ts, const tsk_id_t *sites, tsk_size_t n_sites, + tsk_size_t *num_alleles, tsk_bit_array_t *allele_samples) { int ret = 0; const tsk_flags_t *restrict flags = ts->tables->nodes.flags; const tsk_size_t num_samples = tsk_treeseq_get_num_samples(ts); const tsk_size_t *restrict site_muts_len = ts->site_mutations_length; - const tsk_site_t *restrict site; + tsk_site_t site; tsk_tree_t tree; tsk_bit_array_t all_samples_bits, mut_samples, mut_samples_row, out_row; - tsk_size_t max_muts_len, mut_offset, num_nodes, s, m, n; + tsk_size_t max_muts_len, site_offset, num_nodes, site_idx, s, m, n; tsk_id_t node, *nodes = NULL; void *tmp_nodes; @@ -2374,11 +2419,12 @@ get_mutation_samples( tsk_memset(&all_samples_bits, 0, sizeof(all_samples_bits)); max_muts_len = 0; - for (s = 0; s < ts->tables->sites.num_rows; s++) { - if (site_muts_len[s] > max_muts_len) { - max_muts_len = site_muts_len[s]; + for (s = 0; s < n_sites; s++) { + if (site_muts_len[sites[s]] > max_muts_len) { + max_muts_len = site_muts_len[sites[s]]; } } + // Allocate a bit array of size max alleles for all sites ret = tsk_bit_array_init(&mut_samples, num_samples, max_muts_len); if (ret != 0) { goto out; @@ -2387,103 +2433,111 @@ get_mutation_samples( if (ret != 0) { goto out; } - + get_all_samples_bits(&all_samples_bits, num_samples); ret = tsk_tree_init(&tree, ts, TSK_NO_SAMPLE_COUNTS); if (ret != 0) { goto out; } - // A future improvement could get a union of all sample sets - // instead of all samples - get_all_samples_bits(&all_samples_bits, num_samples); - - // Traverse down each tree, recording all samples below each mutation. 
We perform one - // preorder traversal per mutation. - mut_offset = 0; - for (ret = tsk_tree_first(&tree); ret == TSK_TREE_OK; ret = tsk_tree_next(&tree)) { + // For each mutation within each site, perform one preorder traversal to gather + // the samples under each mutation's node. + site_offset = 0; + for (site_idx = 0; site_idx < n_sites; site_idx++) { + tsk_treeseq_get_site(ts, sites[site_idx], &site); + ret = tsk_tree_seek(&tree, site.position, 0); + if (ret != 0) { + goto out; + } tmp_nodes = tsk_realloc(nodes, tsk_tree_get_size_bound(&tree) * sizeof(*nodes)); if (tmp_nodes == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } nodes = tmp_nodes; - for (s = 0; s < tree.sites_length; s++) { - site = &tree.sites[s]; - tsk_bit_array_get_row(allele_samples, mut_offset, &out_row); - tsk_bit_array_add(&out_row, &all_samples_bits); - // Zero out results before the start of each iteration - tsk_memset(mut_samples.data, 0, - mut_samples.size * max_muts_len * sizeof(tsk_bit_array_value_t)); - for (m = 0; m < site->mutations_length; m++) { - tsk_bit_array_get_row(&mut_samples, m, &mut_samples_row); - node = site->mutations[m].node; - ret = tsk_tree_preorder_from(&tree, node, nodes, &num_nodes); - if (ret != 0) { - goto out; - } - for (n = 0; n < num_nodes; n++) { - node = nodes[n]; - if (flags[node] & TSK_NODE_IS_SAMPLE) { - tsk_bit_array_add_bit( - &mut_samples_row, (tsk_bit_array_value_t) node); - } + + tsk_bit_array_get_row(allele_samples, site_offset, &out_row); + tsk_bit_array_add(&out_row, &all_samples_bits); + + // Zero out results before the start of each iteration + tsk_memset(mut_samples.data, 0, + mut_samples.size * max_muts_len * sizeof(tsk_bit_array_value_t)); + for (m = 0; m < site.mutations_length; m++) { + tsk_bit_array_get_row(&mut_samples, m, &mut_samples_row); + node = site.mutations[m].node; + ret = tsk_tree_preorder_from(&tree, node, nodes, &num_nodes); + if (ret != 0) { + goto out; + } + for (n = 0; n < num_nodes; n++) { + node = nodes[n]; + if (flags[node] & TSK_NODE_IS_SAMPLE) { + tsk_bit_array_add_bit( + &mut_samples_row, (tsk_bit_array_value_t) node); } - mut_offset++; } - mut_offset++; // One more for the ancestral allele - get_allele_samples(site, &mut_samples, &out_row, &(num_alleles[site->id])); } + site_offset += site.mutations_length + 1; + get_allele_samples(&site, &mut_samples, &out_row, &(num_alleles[site_idx])); } - // if adding code below, check ret before continuing +// if adding code below, check ret before continuing out: tsk_safe_free(nodes); tsk_tree_free(&tree); tsk_bit_array_free(&mut_samples); tsk_bit_array_free(&all_samples_bits); - return ret; + return ret == TSK_TREE_OK ? 
0 : ret; } static int tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, const tsk_bit_array_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, norm_func_t *norm_f, - const double *TSK_UNUSED(left_window), const double *TSK_UNUSED(right_window), - tsk_flags_t options, tsk_size_t *result_size, double **result) + sample_count_stat_params_t *f_params, norm_func_t *norm_f, tsk_size_t n_rows, + const tsk_id_t *row_sites, tsk_size_t n_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result) { + int ret = 0; - tsk_bit_array_t allele_samples; - tsk_bit_array_t site_a_state, site_b_state; - tsk_size_t inner, result_offset, inner_offset, a_offset, b_offset; - tsk_size_t site_a, site_b; + tsk_bit_array_t allele_samples, c_state, r_state; bool polarised = false; - const tsk_size_t num_sites = self->tables->sites.num_rows; + tsk_id_t *sites; + tsk_size_t r, c, s, n_alleles, n_sites, *row_idx, *col_idx; + double *result_row; const tsk_size_t num_samples = self->num_samples; - const tsk_size_t max_alleles = self->tables->mutations.num_rows + num_sites; - tsk_size_t *num_alleles = tsk_malloc(num_sites * sizeof(*num_alleles)); - const tsk_size_t *restrict site_muts_len = self->site_mutations_length; + tsk_size_t *num_alleles = NULL, *site_offsets = NULL; + tsk_size_t result_row_len = n_cols * result_dim; tsk_memset(&allele_samples, 0, sizeof(allele_samples)); - if (num_alleles == NULL) { + sites = tsk_malloc(self->tables->sites.num_rows * sizeof(*sites)); + row_idx = tsk_malloc(self->tables->sites.num_rows * sizeof(*row_idx)); + col_idx = tsk_malloc(self->tables->sites.num_rows * sizeof(*col_idx)); + if (sites == NULL || row_idx == NULL || col_idx == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } + get_site_row_col_indices( + n_rows, row_sites, n_cols, col_sites, sites, &n_sites, row_idx, col_idx); - ret = tsk_bit_array_init(&allele_samples, num_samples, max_alleles); - if (ret != 0) { + // We rely on n_sites to allocate these arrays, they're initialized to NULL for safe + // deallocation if the previous allocation fails + num_alleles = tsk_malloc(n_sites * sizeof(*num_alleles)); + site_offsets = tsk_malloc(n_sites * sizeof(*site_offsets)); + if (num_alleles == NULL || site_offsets == NULL) { + ret = TSK_ERR_NO_MEMORY; goto out; } - ret = get_mutation_samples(self, num_alleles, &allele_samples); + + n_alleles = 0; + for (s = 0; s < n_sites; s++) { + site_offsets[s] = n_alleles; + n_alleles += self->site_mutations_length[sites[s]] + 1; + } + ret = tsk_bit_array_init(&allele_samples, num_samples, n_alleles); if (ret != 0) { goto out; } - - // Number of pairs w/ replacement (sites) - *result_size = (num_sites * (1 + num_sites)) / 2U; - *result = tsk_calloc(*result_size * result_dim, sizeof(**result)); - - if (result == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = get_mutation_samples(self, sites, n_sites, num_alleles, &allele_samples); + if (ret != 0) { goto out; } @@ -2491,34 +2545,28 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, polarised = true; } - inner = 0; - a_offset = 0; - b_offset = 0; - inner_offset = 0; - result_offset = 0; - // TODO: implement windows! 
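The loop being removed here walked the upper triangle of all site pairs; the replacement iterates every requested row site against every requested column site, writing one vector of length result_dim (one value per sample set) into each cell of the rectangular output. Below is a rough Python sketch of that iteration pattern, not the C implementation itself: compute_stat is a hypothetical stand-in for compute_general_two_site_stat_result, and row_idx, col_idx, site_offsets and num_alleles mirror the arrays of the same names in the C code.

    import numpy as np

    def rectangular_two_site_sketch(
        row_idx, col_idx, site_offsets, num_alleles, compute_stat, result_dim
    ):
        # One result vector per (row site, column site) pair.
        result = np.zeros((len(row_idx), len(col_idx), result_dim))
        for r, ri in enumerate(row_idx):
            for c, ci in enumerate(col_idx):
                # compute_stat receives the offset of each site's alleles in
                # the allele/sample bit array and the allele count per site.
                result[r, c] = compute_stat(
                    site_offsets[ri], site_offsets[ci],
                    num_alleles[ri], num_alleles[ci],
                )
        return result

When the row and column site lists overlap, symmetric cells are computed twice; that is the redundant computation the commit message accepts in exchange for the simpler loop.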
- for (site_a = 0; site_a < num_sites; site_a++) { - b_offset = inner_offset; - for (site_b = inner; site_b < num_sites; site_b++) { - tsk_bit_array_get_row(&allele_samples, a_offset, &site_a_state); - tsk_bit_array_get_row(&allele_samples, b_offset, &site_b_state); - ret = compute_general_two_site_stat_result(&site_a_state, &site_b_state, - num_alleles[site_a], num_alleles[site_b], num_samples, state_dim, + // For each row/column pair, fill in the sample set in the result matrix. + for (r = 0; r < n_rows; r++) { + result_row = GET_2D_ROW(result, result_row_len, r); + for (c = 0; c < n_cols; c++) { + tsk_bit_array_get_row(&allele_samples, site_offsets[row_idx[r]], &r_state); + tsk_bit_array_get_row(&allele_samples, site_offsets[col_idx[c]], &c_state); + ret = compute_general_two_site_stat_result(&r_state, &c_state, + num_alleles[row_idx[r]], num_alleles[col_idx[c]], num_samples, state_dim, sample_sets, result_dim, f, f_params, norm_f, polarised, - &((*result)[result_offset])); + &(result_row[c * result_dim])); if (ret != 0) { goto out; } - result_offset += result_dim; - b_offset += site_muts_len[site_b] + 1; } - a_offset += site_muts_len[site_a] + 1; - inner_offset += site_muts_len[site_a] + 1; - inner++; } out: + tsk_safe_free(sites); + tsk_safe_free(row_idx); + tsk_safe_free(col_idx); tsk_safe_free(num_alleles); + tsk_safe_free(site_offsets); tsk_bit_array_free(&allele_samples); return ret; } @@ -2558,14 +2606,43 @@ sample_sets_to_bit_array(const tsk_treeseq_t *self, const tsk_size_t *sample_set return ret; } +static int +check_sites(const tsk_id_t *sites, tsk_size_t num_sites, tsk_size_t num_site_rows) +{ + int ret = 0; + tsk_size_t i; + + if (sites == NULL || num_sites == 0) { + ret = TSK_ERR_BAD_SITE_POSITION; // TODO: error should be no sites? + goto out; + } + + for (i = 0; i < num_sites - 1; i++) { + if (sites[i] < 0 || sites[i] >= (tsk_id_t) num_site_rows) { + ret = TSK_ERR_SITE_OUT_OF_BOUNDS; + goto out; + } + if (sites[i] >= sites[i + 1]) { + // TODO: this checks no repeats, but error is ambiguous + ret = TSK_ERR_UNSORTED_SITES; + goto out; + } + } + // check the last value + if (sites[i] < 0 || sites[i] >= (tsk_id_t) num_site_rows) { + ret = TSK_ERR_SITE_OUT_OF_BOUNDS; + goto out; + } +out: + return ret; +} + static int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, - norm_func_t *norm_f, tsk_size_t TSK_UNUSED(num_left_windows), - const double *left_windows, tsk_size_t TSK_UNUSED(num_right_windows), - const double *right_windows, tsk_flags_t options, tsk_size_t *result_size, - double **result) + norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, + tsk_size_t out_cols, const tsk_id_t *col_sites, tsk_flags_t options, double *result) { // TODO: generalize this function if we ever decide to do weighted two_locus stats. // We only implement count stats and therefore we don't handle weights. 
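check_sites (above) requires each site list to be non-empty, in bounds, and strictly increasing, which guarantees the lists are sorted and duplicate free before get_site_row_col_indices merges them. The following is a minimal Python sketch of the merge semantics only, assuming both lists have already passed validation; it is not the single-pass merge used in the C code or in the Python prototype further down.

    row_sites = [0, 1]
    col_sites = [1, 2]

    # Deduplicated union of the requested sites, plus each list's positions
    # within that union.
    sites = sorted(set(row_sites) | set(col_sites))  # [0, 1, 2]
    row_idx = [sites.index(s) for s in row_sites]    # [0, 1]
    col_idx = [sites.index(s) for s in col_sites]    # [1, 2]

Only the sites in the union are visited when allele samples are collected, so the shared site (1 here) is processed once even though it appears in both requests.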
@@ -2601,8 +2678,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl // goto out; // } - tsk_bug_assert(left_windows == NULL && right_windows == NULL); - ret = tsk_treeseq_check_sample_sets( self, num_sample_sets, sample_set_sizes, sample_sets); if (ret != 0) { @@ -2615,9 +2690,17 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl } if (stat_site) { + ret = check_sites(row_sites, out_rows, self->tables->sites.num_rows); + if (ret != 0) { + goto out; + } + ret = check_sites(col_sites, out_cols, self->tables->sites.num_rows); + if (ret != 0) { + goto out; + } ret = tsk_treeseq_two_site_count_stat(self, state_dim, &sample_sets_bits, - result_dim, f, &f_params, norm_f, left_windows, right_windows, options, - result_size, result); + result_dim, f, &f_params, norm_f, out_rows, row_sites, out_cols, col_sites, + options, result); } else { ret = TSK_ERR_UNSUPPORTED_STAT_MODE; } @@ -3451,16 +3534,14 @@ D_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_D(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result) + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result) { options |= TSK_STAT_POLARISED; // TODO: allow user to pick? return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D_summary_func, norm_total_weighted, - num_left_windows, left_windows, num_right_windows, right_windows, options, - result_size, result); + num_rows, row_sites, num_cols, col_sites, options, result); } static int @@ -3490,15 +3571,13 @@ D2_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_D2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result) + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result) { return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D2_summary_func, norm_total_weighted, - num_left_windows, left_windows, num_right_windows, right_windows, options, - result_size, result); + num_rows, row_sites, num_cols, col_sites, options, result); } static int @@ -3534,15 +3613,13 @@ r2_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_r2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result) + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result) { return 
tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, - sample_sets, num_sample_sets, NULL, r2_summary_func, norm_hap_weighted, - num_left_windows, left_windows, num_right_windows, right_windows, options, - result_size, result); + sample_sets, num_sample_sets, NULL, r2_summary_func, norm_hap_weighted, num_rows, + row_sites, num_cols, col_sites, options, result); } static int @@ -3576,16 +3653,14 @@ D_prime_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_D_prime(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result) + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result) { options |= TSK_STAT_POLARISED; // TODO: allow user to pick? return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D_prime_summary_func, norm_hap_weighted, - num_left_windows, left_windows, num_right_windows, right_windows, options, - result_size, result); + num_rows, row_sites, num_cols, col_sites, options, result); } static int @@ -3621,16 +3696,14 @@ r_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_r(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result) + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result) { options |= TSK_STAT_POLARISED; // TODO: allow user to pick? 
return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, r_summary_func, norm_total_weighted, - num_left_windows, left_windows, num_right_windows, right_windows, options, - result_size, result); + num_rows, row_sites, num_cols, col_sites, options, result); } static int @@ -3661,15 +3734,13 @@ Dz_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_Dz(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result) + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result) { return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, Dz_summary_func, norm_total_weighted, - num_left_windows, left_windows, num_right_windows, right_windows, options, - result_size, result); + num_rows, row_sites, num_cols, col_sites, options, result); } static int @@ -3697,15 +3768,13 @@ pi2_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_pi2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result) + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result) { return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, pi2_summary_func, norm_total_weighted, - num_left_windows, left_windows, num_right_windows, right_windows, options, - result_size, result); + num_rows, row_sites, num_cols, col_sites, options, result); } /*********************************** diff --git a/c/tskit/trees.h b/c/tskit/trees.h index dbc870ad2f..2faa3c95c6 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1046,41 +1046,39 @@ int tsk_treeseq_genetic_relatedness(const tsk_treeseq_t *self, const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); +typedef int two_locus_count_stat_method(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_rows, const tsk_id_t *row_sites, + tsk_size_t num_cols, const tsk_id_t *col_sites, tsk_flags_t options, double *result); + int tsk_treeseq_D(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result); + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result); int tsk_treeseq_D2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, 
- tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result); + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result); int tsk_treeseq_r2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result); + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result); int tsk_treeseq_D_prime(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result); + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result); int tsk_treeseq_r(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result); + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result); int tsk_treeseq_Dz(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result); + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result); int tsk_treeseq_pi2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_left_windows, const double *left_windows, - tsk_size_t num_right_windows, const double *right_windows, tsk_flags_t options, - tsk_size_t *result_size, double **result); + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result); /* Three way sample set stats */ int tsk_treeseq_Y3(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py new file mode 100644 index 0000000000..40ccd8f480 --- /dev/null +++ b/python/tests/test_ld_matrix.py @@ -0,0 +1,833 @@ +import io +from itertools import combinations_with_replacement +from itertools import permutations +from typing import Any +from typing import Callable +from typing import Dict +from typing 
import Generator +from typing import List +from typing import Tuple + +import numpy as np +import pytest + +import tskit + + +class BitSet: + """BitSet object, which stores values in arrays of unsigned integers. + The rows represent all possible values a bit can take, and the rows + represent each item that can be stored in the array. + + :param num_bits: The number of values that a single row can contain. + :param length: The number of rows. + """ + + DTYPE = np.uint32 # Data type to be stored in the bitset + CHUNK_SIZE = DTYPE(32) # Size of integer field to store the data in + + def __init__(self: "BitSet", num_bits: int, length: int) -> None: + self.row_len = num_bits // self.CHUNK_SIZE + self.row_len += 1 if num_bits % self.CHUNK_SIZE else 0 + self.row_len = int(self.row_len) + self.data = np.zeros(self.row_len * length, dtype=self.DTYPE) + + def intersect( + self: "BitSet", self_row: int, other: "BitSet", other_row: int, out: "BitSet" + ) -> None: + """Intersect a row from the current array instance with a row from + another BitSet and store it in an output bit array of length 1. + + NB: we don't specify the row in the output array, it is expected + to be length 1. + + :param self_row: Row from the current array instance to be intersected. + :param other: Other BitSet to intersect with. + :param other_row: Row from the other BitSet instance. + :param out: BitArray to store the result. + """ + self_offset = self_row * self.row_len + other_offset = other_row * self.row_len + + for i in range(self.row_len): + out.data[i] = self.data[i + self_offset] & other.data[i + other_offset] + + def difference( + self: "BitSet", self_row: int, other: "BitSet", other_row: int + ) -> None: + """Take the difference between the current array instance and another + array instance. Store the result in the specified row of the current + instance. + + :param self_row: Row from the current array from which to subtract. + :param other: Other BitSet to subtract from the current instance. + :param other_row: Row from the other BitSet instance. + """ + self_offset = self_row * self.row_len + other_offset = other_row * self.row_len + + for i in range(self.row_len): + self.data[i + self_offset] &= ~(other.data[i + other_offset]) + + def union(self: "BitSet", self_row: int, other: "BitSet", other_row: int) -> None: + """Take the union between the current array instance and another + array instance. Store the result in the specified row of the current + instance. + + :param self_row: Row from the current array with which to union. + :param other: Other BitSet to union with the current instance. + :param other_row: Row from the other BitSet instance. + """ + self_offset = self_row * self.row_len + other_offset = other_row * self.row_len + + for i in range(self.row_len): + self.data[i + self_offset] |= other.data[i + other_offset] + + def add(self: "BitSet", row: int, bit: int) -> None: + """Add a single bit to the row of a bit array + + :param row: Row to be modified. + :param bit: Bit to be added. + """ + offset = row * self.row_len + i = bit // self.CHUNK_SIZE + self.data[i + offset] |= self.DTYPE(1) << (bit - (self.CHUNK_SIZE * i)) + + def get_items(self: "BitSet", row: int) -> Generator[int, None, None]: + """Get the items stored in the row of a bitset + + :param row: Row from the array to list from. + :returns: A generator of integers stored in the array. 
+ """ + offset = row * self.row_len + for i in range(self.row_len): + for item in range(self.CHUNK_SIZE): + if self.data[i + offset] & (self.DTYPE(1) << item): + yield item + (i * self.CHUNK_SIZE) + + def contains(self: "BitSet", row: int, bit: int) -> bool: + """Test if a bit is contained within a bit array row + + :param row: Row to test. + :param bit: Bit to check. + :returns: True if the bit is set in the row, else false. + """ + i = bit // self.CHUNK_SIZE + offset = row * self.row_len + return bool( + self.data[i + offset] & (self.DTYPE(1) << (bit - (self.CHUNK_SIZE * i))) + ) + + def count(self: "BitSet", row: int) -> int: + """Count all of the set bits in a specified row. Uses a SWAR + algorithm to count in parallel with a constant number (12) of operations. + + NB: we have to cast all values to our unsigned dtype to avoid type promotion + + Details here: + # https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel + + :param row: Row to count. + :returns: Count of all of the set bits. + """ + count = 0 + offset = row * self.row_len + D = self.DTYPE + + for i in range(offset, offset + self.row_len): + v = self.data[i] + v = v - ((v >> D(1)) & D(0x55555555)) + v = (v & D(0x33333333)) + ((v >> D(2)) & D(0x33333333)) + # this operation relies on integer overflow + with np.errstate(over="ignore"): + count += ((v + (v >> D(4)) & D(0xF0F0F0F)) * D(0x1010101)) >> D(24) + + return count + + def count_naive(self: "BitSet", row: int) -> int: + """Naive counting algorithm implementing the same functionality as the count + method. Useful for testing correctness, uses the same number of operations + as set bits. + + Details here: + # https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetNaive + + :param row: Row to count. + :returns: Count of all of the set bits. + """ + count = 0 + offset = row * self.row_len + + for i in range(offset, offset + self.row_len): + v = self.data[i] + while v: + v &= v - self.DTYPE(1) + count += self.DTYPE(1) + return count + + +def norm_hap_weighted( + state_dim: int, + hap_weights: np.ndarray, + n_a: int, + n_b: int, + result: np.ndarray, + params: Dict[str, Any], +) -> None: + """Create a vector of normalizing coefficients, length of the number of + sample sets. In this normalization strategy, we weight each allele's + statistic by the proportion of the haplotype present. + + :param state_dim: Number of sample sets. + :param hap_weights: Proportion of each two-locus haplotype. + :param n_a: Number of alleles at the A locus. + :param n_b: Number of alleles at the B locus. + :param result: Result vector to store the normalizing coefficients in. + :param params: Params of summary function. + """ + del n_a, n_b # handle unused params + sample_set_sizes = params["sample_set_sizes"] + for k in range(state_dim): + n = sample_set_sizes[k] + result[k] = hap_weights[0, k] / n + + +def norm_total_weighted( + state_dim: int, + hap_weights: np.ndarray, + n_a: int, + n_b: int, + result: np.ndarray, + params: Dict[str, Any], +) -> None: + """Create a vector of normalizing coefficients, length of the number of + sample sets. In this normalization strategy, we weight each allele's + statistic by the product of the allele frequencies + + :param state_dim: Number of sample sets. + :param hap_weights: Proportion of each two-locus haplotype. + :param n_a: Number of alleles at the A locus. + :param n_b: Number of alleles at the B locus. + :param result: Result vector to store the normalizing coefficients in. + :param params: Params of summary function. 
+ """ + del hap_weights, params # handle unused params + for k in range(state_dim): + result[k] = 1 / (n_a * n_b) + + +def check_sites(sites, max_sites): + """Validate the specified site ids. + + We require that sites are: + + 1) Within the boundaries of available sites in the tree sequence + 2) Sorted + 3) Non-repeating + + Raises an exception if any error is found. + + :param sites: 1d array of sites to validate. + :param max_sites: Number of sites in the tree sequence, the upper + bound value for site ids. + """ + if sites is None or len(sites) == 0: + raise ValueError("No sites provided") + i = 0 + for i in range(len(sites) - 1): + if sites[i] < 0 or sites[i] >= max_sites: + raise ValueError(f"Site out of bounds: {sites[i]}") + if sites[i] >= sites[i + 1]: + raise ValueError(f"Sites not sorted: {sites[i], sites[i + 1]}") + if sites[-1] < 0 or sites[-1] >= max_sites: + raise ValueError(f"Site out of bounds: {sites[i + 1]}") + + +def get_site_row_col_indices( + row_sites: List[int], col_sites: List[int] +) -> Tuple[List[int], List[int], List[int]]: + """Co-iterate over the row and column sites, keeping a sorted union of + site values and an index into the unique list of sites for both the row + and column sites. This function produces a list of sites of interest and + row and column indexes into this list of sites. + + NB: This routine requires that the site lists are sorted and deduplicated. + + :param row_sites: List of sites that will be represented in the output + matrix rows. + :param col_sites: List of sites that will be represented in the output + matrix columns. + :returns: Tuple of lists of sites, row, and column indices. + """ + r = 0 + c = 0 + s = 0 + sites = [] + col_idx = [] + row_idx = [] + + while r < len(row_sites) and c < len(col_sites): + if row_sites[r] < col_sites[c]: + sites.append(row_sites[r]) + row_idx.append(s) + s += 1 + r += 1 + elif row_sites[r] > col_sites[c]: + sites.append(col_sites[c]) + col_idx.append(s) + s += 1 + c += 1 + else: + sites.append(row_sites[r]) + row_idx.append(s) + col_idx.append(s) + s += 1 + r += 1 + c += 1 + while r < len(row_sites): + sites.append(row_sites[r]) + row_idx.append(s) + s += 1 + r += 1 + while c < len(col_sites): + sites.append(col_sites[c]) + col_idx.append(s) + s += 1 + c += 1 + + return sites, row_idx, col_idx + + +def get_all_samples_bits(num_samples: int) -> BitSet: + """Get the bits for all samples in the tree sequence. This is achieved + by creating a length 1 bitset and adding every sample's bit to it. + + :param num_samples: Number of samples contained in the tree sequence. + :returns: Length 1 BitSet containing all samples in the tree sequence. + """ + all_samples = BitSet(num_samples, 1) + for i in range(num_samples): + all_samples.add(0, i) + return all_samples + + +def get_allele_samples( + site: tskit.Site, site_offset: int, mut_samples: BitSet, allele_samples: BitSet +) -> int: + """Given a BitSet that has been arranged so that we have every sample under + a given mutation's node, create the final output where we know which samples + should belong under each mutation, considering the mutation's parentage, + back mutations, and ancestral state. + + To this end, we iterate over each mutation and store the samples under the + focal mutation in the output BitSet (allele_samples). Then, we check the + parent of the focal mutation (either a mutation or the ancestral allele), + and we subtract the samples in the focal mutation from the parent allele's + samples. 
+
+    :param site: Focal site for which to adjust mutation data.
+    :param site_offset: Offset into allele_samples for our focal site.
+    :param mut_samples: BitSet containing the samples under each mutation in the
+                        focal site.
+    :param allele_samples: Output BitSet, initially passed in with all of the
+                           tree sequence samples set in the ancestral allele
+                           state.
+    :returns: number of alleles actually encountered (adjusting for back-mutation).
+    """
+    alleles = []
+    num_alleles = 1
+    alleles.append(site.ancestral_state)
+
+    for m, mut in enumerate(site.mutations):
+        try:
+            allele = alleles.index(mut.derived_state)
+        except ValueError:
+            allele = len(alleles)
+            alleles.append(mut.derived_state)
+            num_alleles += 1
+        allele_samples.union(allele + site_offset, mut_samples, m)
+        # find the parent allele from which we must subtract
+        alt_allele_state = site.ancestral_state
+        if mut.parent != tskit.NULL:
+            parent_mut = site.mutations[mut.parent - site.mutations[0].id]
+            alt_allele_state = parent_mut.derived_state
+        alt_allele = alleles.index(alt_allele_state)
+        # subtract the focal allele's samples from the alt allele
+        allele_samples.difference(
+            alt_allele + site_offset, allele_samples, allele + site_offset
+        )
+
+    return num_alleles
+
+
+def get_mutation_samples(
+    ts: tskit.TreeSequence, sites: List[int]
+) -> Tuple[np.ndarray, np.ndarray, BitSet]:
+    """For a given set of sites, generate a BitSet of all samples possessing
+    each allelic state for each site. This includes the ancestral state, along
+    with any mutations contained in the site.
+
+    We achieve this by starting at the tree containing the first site in our
+    list, then walking along each tree until we've encountered the tree
+    containing the last site in our list. Along the way, we perform a preorder
+    traversal from the node of each mutation in a given site, storing the
+    samples under that particular node. After we've stored all of the samples
+    for each allele at a site, we adjust each allele's samples by removing
+    samples that have a different allele at a child mutation down the tree (see
+    get_allele_samples for more details).
+
+    We also gather some ancillary data while we iterate over the sites: the
+    number of alleles for each site, and the offset of each site. The number of
+    alleles at each site is the count of mutations plus the ancestral allele.
+    The offset for each site indicates how many array entries we must skip
+    (i.e. how many alleles exist before a specific site's entry) in order to
+    address the data for a given site.
+
+    :param ts: Tree sequence to gather data from.
+    :param sites: Subset of sites to consider when gathering data.
+    :returns: Tuple of the number of alleles per site, site offsets, and the
+              BitSet of all samples in each allelic state.
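+
+    Illustrative example (not part of the original prototype): for three
+    biallelic sites (one mutation each), num_alleles is [2, 2, 2] and
+    site_offsets is [0, 2, 4], so the alleles of site k occupy rows
+    site_offsets[k] and site_offsets[k] + 1 of allele_samples.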
+ """ + num_alleles = np.zeros(len(sites), dtype=np.uint64) + site_offsets = np.zeros(len(sites), dtype=np.uint64) + all_samples = get_all_samples_bits(ts.num_samples) + allele_samples = BitSet( + ts.num_samples, sum(len(ts.site(i).mutations) + 1 for i in sites) + ) + + site_offset = 0 + site_idx = 0 + for site_idx, site_id in enumerate(sites): + site = ts.site(site_id) + tree = ts.at(site.position) + # initialize the ancestral allele with all samples + allele_samples.union(site_offset, all_samples, 0) + # store samples for each mutation in mut_samples + mut_samples = BitSet(ts.num_samples, len(site.mutations)) + for m, mut in enumerate(site.mutations): + for node in tree.preorder(mut.node): + if ts.node(node).is_sample(): + mut_samples.add(m, node) + # account for mutation parentage, subtract samples from mutation parents + num_alleles[site_idx] = get_allele_samples( + site, site_offset, mut_samples, allele_samples + ) + # increment the offset for ancestral + mutation alleles + site_offsets[site_idx] = site_offset + site_offset += len(site.mutations) + 1 + + return num_alleles, site_offsets, allele_samples + + +def compute_general_two_site_stat_result( + row_site_offset: int, + col_site_offset: int, + num_row_alleles: int, + num_col_alleles: int, + num_samples: int, + allele_samples: BitSet, + state_dim: int, + sample_sets: BitSet, + func: Callable[[int, np.ndarray, np.ndarray, Dict[str, Any]], None], + norm_func: Callable[[int, np.ndarray, int, int, np.ndarray, Dict[str, Any]], None], + params: Dict[str, Any], + polarised: bool, + result: np.ndarray, +) -> None: + """For a given pair of sites, compute the summary statistic for the allele + frequencies for each allelic state of the two pairs. + + :param row_site_offset: Offset of the row site's data in the allele_samples. + :param row_site_offset: Offset of the col site's data in the allele_samples. + :param num_row_alleles: Number of alleles in the row site. + :param num_col_alleles: Number of alleles in the col site. + :param num_samples: Number of samples in tree sequence. + :param allele_samples: BitSet containing the samples with each allelic state + for each site of interest. + :param state_dim: Number of sample sets. + :param sample_sets: BitSet of sample sets to be intersected with the samples + contained within each allele. + :param func: Summary function used to compute each two-locus statistic. + :param norm_func: Function used to generate the normalization coefficients + for each statistic. + :param params: Parameters to pass to the norm and summary function. + :param polarised: If true, skip the computation of the statistic for the + ancestral state. + :param result: Vector of the results matrix to populate. We will produce one + value per sample set, hence the vector of length state_dim. 
+ """ + ss_A_samples = BitSet(num_samples, 1) + ss_B_samples = BitSet(num_samples, 1) + ss_AB_samples = BitSet(num_samples, 1) + AB_samples = BitSet(num_samples, 1) + weights = np.zeros((3, state_dim), np.float64) + norm = np.zeros(state_dim, np.float64) + result_tmp = np.zeros(state_dim, np.float64) + + polarised_val = 1 if polarised else 0 + + for mut_a in range(polarised_val, num_row_alleles): + a = int(mut_a + row_site_offset) + for mut_b in range(polarised_val, num_col_alleles): + b = int(mut_b + col_site_offset) + allele_samples.intersect(a, allele_samples, b, AB_samples) + for k in range(state_dim): + allele_samples.intersect(a, sample_sets, k, ss_A_samples) + allele_samples.intersect(b, sample_sets, k, ss_B_samples) + AB_samples.intersect(0, sample_sets, k, ss_AB_samples) + + w_AB = ss_AB_samples.count(0) + w_A = ss_A_samples.count(0) + w_B = ss_B_samples.count(0) + + weights[0, k] = w_AB + weights[1, k] = w_A - w_AB # w_Ab + weights[2, k] = w_B - w_AB # w_aB + + func(state_dim, weights, result_tmp, params) + + norm_func( + state_dim, + weights, + num_row_alleles - polarised_val, + num_col_alleles - polarised_val, + norm, + params, + ) + + for k in range(state_dim): + result[k] += result_tmp[k] * norm[k] + + +def two_site_count_stat( + ts: tskit.TreeSequence, + func: Callable[[int, np.ndarray, np.ndarray, Dict[str, Any]], None], + norm_func: Callable[[int, np.ndarray, int, int, np.ndarray, Dict[str, Any]], None], + num_sample_sets: int, + sample_set_sizes: np.ndarray, + sample_sets: BitSet, + row_sites: List[int], + col_sites: List[int], + polarised: bool, +) -> np.ndarray: + """Outer function that generates the high-level intermediates used in the + computation of our two-locus statistics. First, we compute the row and + column indices for our unique list of sites, then we get each sample for + each allele in our list of specified sites. + + With those intermediates in hand, we iterate over the row and column indices + to compute comparisons between each of the specified lists of sites. We pass + a vector of results to the computation, which will compute a single result + for each sample set, inserting that into our result matrix. + + :param ts: Tree sequence to gather data from. + :param func: Function used to compute each two-locus statistic. + :param norm_func: Function used to generate the normalization coefficients + for each statistic. + :param num_sample_sets: Number of sample sets that we will consider. + :param sample_set_sizes: Number of samples in each sample set. + :param sample_sets: BitSet of samples to compute stats for. We will only + consider these samples in our computations, resulting + in stats that are computed on subsets of the samples + on the tree sequence. + :param row_sites: Sites contained in the rows of the output matrix. + :param col_sites: Sites contained in the columns of the output matrix. + :param polarised: If true, skip the computation of the statistic for the + ancestral state. + :returns: 3D array of results, dimensions (sample_sets, row_sites, col_sites). 
+ """ + state_dim = len(sample_set_sizes) + params = {"sample_set_sizes": sample_set_sizes} + result = np.zeros( + (num_sample_sets, len(row_sites), len(col_sites)), dtype=np.float64 + ) + + sites, row_idx, col_idx = get_site_row_col_indices(row_sites, col_sites) + num_alleles, site_offsets, allele_samples = get_mutation_samples(ts, sites) + + for row, row_site in enumerate(row_idx): + for col, col_site in enumerate(col_idx): + compute_general_two_site_stat_result( + site_offsets[row_site], + site_offsets[col_site], + num_alleles[row_site], + num_alleles[col_site], + ts.num_samples, + allele_samples, + state_dim, + sample_sets, + func, + norm_func, + params, + polarised, + result[:, row, col], + ) + + return result + + +def sample_sets_to_bit_array( + ts: tskit.TreeSequence, sample_sets: List[List[int]] +) -> Tuple[np.ndarray, BitSet]: + """Convert the list of sample ids to a bit array. This function takes + sample identifiers and maps them to their enumerated integer values, then + stores these values in a bit array. We produce a BitArray and a numpy + array of integers that specify how many samples there are in each sample set. + + NB: this function's type signature is of type integer, but I believe this + could be expanded to Any, currently untested so the integer + specification remains. + + :param ts: Tree sequence to gather data from. + :param sample_sets: List of sample identifiers to store in bit array. + :returns: Tuple containing numpy array of sample set sizes and the sample + set BitSet. + """ + sample_sets_bits = BitSet(ts.num_samples, len(sample_sets)) + sample_index_map = -np.ones(ts.num_nodes, dtype=np.int32) + sample_set_sizes = np.zeros(len(sample_sets), dtype=np.uint64) + + for i, sample in enumerate(ts.samples()): + sample_index_map[sample] = i + + for k, sample_set in enumerate(sample_sets): + sample_set_sizes[k] = len(sample_set) + for sample in sample_set: + sample_index = sample_index_map[sample] + if sample_index == tskit.NULL: + raise ValueError(f"Sample out of bounds: {sample}") + if sample_sets_bits.contains(k, sample_index): + raise ValueError(f"Duplicate sample detected: {sample}") + sample_sets_bits.add(k, sample_index) + + return sample_set_sizes, sample_sets_bits + + +def two_locus_count_stat( + ts, + summary_func, + norm_func, + polarised, + sites=None, + sample_sets=None, +): + """Outer wrapper for two site general stat functionality. Perform some input + validation, get the site index and allele state, then compute the LD matrix. + + TODO: implement mode switching for branch stats + + :param ts: Tree sequence to gather data from. + :param summary_func: Function used to compute each two-locus statistic. + :param norm_func: Function used to generate the normalization coefficients + for each statistic. + :param polarised: If true, skip the computation of the statistic for the + ancestral state. + :param sites: List of two lists containing [row_sites, column_sites]. + :param sample_sets: List of lists of samples to compute stats for. We will + only consider these samples in our computations, + resulting in stats that are computed on subsets of the + samples on the tree sequence. + :returns: 3d numpy array containing LD for (sample_set,row_site,column_site) + unless one or no sample sets are specified, then 2d array + containing LD for (row_site,column_site). 
+ """ + if sample_sets is None: + sample_sets = [ts.samples()] + if sites is None: + sites = [np.arange(ts.num_sites), np.arange(ts.num_sites)] + else: + if len(sites) != 2: + raise ValueError( + f"Sites must be a length 2 list, got a length {len(sites)} list" + ) + sites[0] = np.asarray(sites[0]) + sites[1] = np.asarray(sites[1]) + + row_sites, col_sites = sites + check_sites(row_sites, ts.num_sites) + check_sites(col_sites, ts.num_sites) + + ss_sizes, ss_bits = sample_sets_to_bit_array(ts, sample_sets) + + result = two_site_count_stat( + ts, + summary_func, + norm_func, + len(ss_sizes), + ss_sizes, + ss_bits, + sites[0], + sites[1], + polarised, + ) + + # If there is one sample set, return a 2d numpy array of row/site LD + if len(sample_sets) == 1: + return result.reshape(result.shape[1:3]) + return result + + +def r2_summary_func( + state_dim: int, state: np.ndarray, result: np.ndarray, params: Dict[str, Any] +) -> None: + """Summary function for the r2 statistic. We first compute the proportion of + AB, A, and B haplotypes, then we compute the r2 statistic, storing the outputs + in the result vector, one entry per sample set. + + :param state_dim: Number of sample sets. + :param state: Counts of 3 haplotype configurations for each sample set. + :param result: Vector of length state_dim to store the results in. + :param params: Parameters for the summary function. + """ + sample_set_sizes = params["sample_set_sizes"] + for k in range(state_dim): + n = sample_set_sizes[k] + p_AB = state[0, k] / n + p_Ab = state[1, k] / n + p_aB = state[2, k] / n + + p_A = p_AB + p_Ab + p_B = p_AB + p_aB + + D = p_AB - (p_A * p_B) + denom = p_A * p_B * (1 - p_A) * (1 - p_B) + + if denom == 0 and D == 0: + result[k] = 0 + else: + result[k] = (D * D) / denom + + +def get_paper_ex_ts(): + """Generate the tree sequence example from the tskit paper + + Data taken from the tests: + https://github.com/tskit-dev/tskit/blob/61a844a/c/tests/testlib.c#L55-L96 + + :returns: Tree sequence + """ + nodes = """\ + is_sample time population individual + 1 0 -1 0 + 1 0 -1 0 + 1 0 -1 1 + 1 0 -1 1 + 0 0.071 -1 -1 + 0 0.090 -1 -1 + 0 0.170 -1 -1 + 0 0.202 -1 -1 + 0 0.253 -1 -1 + """ + + edges = """\ + left right parent child + 2 10 4 2 + 2 10 4 3 + 0 10 5 1 + 0 2 5 3 + 2 10 5 4 + 0 7 6 0,5 + 7 10 7 0,5 + 0 2 8 2,6 + """ + + sites = """\ + position ancestral_state + 1 0 + 4.5 0 + 8.5 0 + """ + + mutations = """\ + site node derived_state + 0 2 1 + 1 0 1 + 2 5 1 + """ + + individuals = """\ + flags location parents + 0 0.2,1.5 -1,-1 + 0 0.0,0.0 -1,-1 + """ + + return tskit.load_text( + nodes=io.StringIO(nodes), + edges=io.StringIO(edges), + sites=io.StringIO(sites), + individuals=io.StringIO(individuals), + mutations=io.StringIO(mutations), + strict=False, + ) + + +# fmt:off +# true r2 values for the tree sequence from the tskit paper +PAPER_EX_TRUTH_MATRIX = np.array( + [[1.0, 0.11111111, 0.11111111], # noqa: E241 + [0.11111111, 1.0, 1.0], # noqa: E241 + [0.11111111, 1.0, 1.0]] # noqa: E241 +) +# fmt:on + + +def get_all_site_partitions(n): + """Generate all partitions for square matricies, then combine with replacement + and return all possible pairs of all partitions. + + TODO: only works for square matricies, would need to generate two lists of + partitions to get around this + + :param n: length of one dimension of the !square! matrix. + :returns: combinations of partitions. 
+ """ + parts = [] + for part in tskit.combinatorics.rule_asc(3): + for g in set(permutations(part, len(part))): + p = [] + i = iter(range(n)) + for item in g: + p.append([next(i) for _ in range(item)]) + parts.append(p) + combos = [] + for a, b in combinations_with_replacement({tuple(j) for i in parts for j in i}, 2): + combos.append((a, b)) + combos.append((b, a)) + combos = [[list(a), list(b)] for a, b in set(combos)] + return combos + + +def assert_slice_allclose(a, b): + """Provide two lists of sites to the general stat function, then check to + see if the subset matches the slice out of the truth matrix. Raise if + arrays not close. + + :param a: row sites. + :param b: column sites. + """ + ts = get_paper_ex_ts() + np.testing.assert_allclose( + two_locus_count_stat( + ts, r2_summary_func, norm_hap_weighted, False, sites=[a, b] + ), + PAPER_EX_TRUTH_MATRIX[a[0] : a[-1] + 1, b[0] : b[-1] + 1], + ) + + +@pytest.mark.parametrize( + # Generate all partitions of the LD matrix that, then pass into test_subset + "partition", + get_all_site_partitions(len(PAPER_EX_TRUTH_MATRIX)), +) +def test_subset(partition): + """Given a partition of the truth matrix, check that we can successfully + compute the LD matrix for that given partition, effectively ensuring that + our handling of site subsets is correct. + + :param partition: length 2 list of [row_sites, column_sites]. This is a + pytest fixture for a parametrized function. + """ + a, b = partition + print(a, b) + assert_slice_allclose(a, b)