Skip to content

Commit

Permalink
Refactor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Sep 13, 2023
1 parent 36e63c6 commit 9822fd9
Showing 1 changed file with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions python/tests/test_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,38 +507,34 @@ def test_compare_allele_probabilities(input_ref, input_query, expected):
pos = (np.arange(9) + 1) * 1e4
num_query_haps = input_query.shape[0]
for i in np.arange(num_query_haps):
_, i_allele_probs = tests.beagle.run_beagle(
_, allele_probs = tests.beagle.run_beagle(
input_ref, input_query[i], pos, miscall_rate=0.0001, ne=10.0
)
# Rescale probabilities before comparison
rescaled_probs = np.zeros_like(i_allele_probs).T
for j in np.arange(rescaled_probs.shape[1]):
rescaled_probs[:, j] = i_allele_probs[j] / np.sum(i_allele_probs[j])
assert np.allclose(rescaled_probs, expected[i], atol=1e-04)
assert np.allclose(allele_probs.T, expected[i], atol=1e-04)


@pytest.mark.parametrize(
"input_ref,input_query,expected",
"input_ref,input_query",
[
(toy_ref_0, toy_query_0),
],
)
def test_beagle_numba(input_ref, input_query, expected):
def test_beagle_numba(input_ref, input_query):
pos = (np.arange(9) + 1) * 1e4
num_query_haps = input_query.shape[0]
for i in np.arange(num_query_haps):
imputed_alleles, i_allele_probs = tests.beagle.run_beagle(
imputed_alleles, allele_probs = tests.beagle.run_beagle(
input_ref, input_query[i], pos, miscall_rate=0.0001, ne=10.0
)
imputed_alleles_numba, i_allele_probs_numba = tests.beagle_numba.run_beagle(
imputed_alleles_numba, allele_probs_numba = tests.beagle_numba.run_beagle(
input_ref, input_query[i], pos, miscall_rate=0.0001, ne=10.0
)
assert np.array_equal(imputed_alleles, imputed_alleles_numba)
assert np.allclose(i_allele_probs, i_allele_probs_numba)
assert np.allclose(allele_probs, allele_probs_numba)


@pytest.mark.parametrize(
"input_ref,input_query,expected",
"input_ref,input_query",
[
(toy_ref_0, toy_query_0),
],
Expand All @@ -547,14 +543,14 @@ def test_tsimpute(input_ref, input_query):
pos = (np.arange(9) + 1) * 1e4
num_query_haps = input_query.shape[0]
for i in np.arange(num_query_haps):
imputed_alleles, i_allele_probs = tests.beagle.run_beagle(
imputed_alleles, allele_probs = tests.beagle.run_beagle(
input_ref, input_query[i], pos, miscall_rate=0.0001, ne=10.0
)
imputed_alleles_ts, i_allele_probs_ts = tests.beagle_numba.run_tsimpute(
imputed_alleles_ts, allele_probs_ts = tests.beagle_numba.run_tsimpute(
input_ref, input_query[i], pos
)
assert np.array_equal(imputed_alleles, imputed_alleles_ts)
assert np.allclose(i_allele_probs, i_allele_probs_ts)
assert np.allclose(allele_probs, allele_probs_ts)


# Below is toy data set case 7 in tree sequence format.
Expand Down

0 comments on commit 9822fd9

Please sign in to comment.