diff --git a/python/tests/test_imputation.py b/python/tests/test_imputation.py index 2c2029044c..96135fe90f 100644 --- a/python/tests/test_imputation.py +++ b/python/tests/test_imputation.py @@ -656,7 +656,7 @@ def test_tsimpute(input_ref, input_query): # Tests for helper functions. @pytest.mark.parametrize( - "genotyped_pos,ungenotyped_pos,expected_pos,expected_idx", + "genotyped_pos,ungenotyped_pos,expected_weights,expected_idx", [ # All ungenotyped markers are between genotyped markers. ( @@ -690,12 +690,14 @@ def test_tsimpute(input_ref, input_query): (np.array([10, 20]), np.array([15]), np.array([0.001]), np.array([0])), ], ) -def test_get_weights(genotyped_pos, ungenotyped_pos, expected_pos, expected_idx): - actual_pos, actual_idx = tests.beagle.get_weights(genotyped_pos, ungenotyped_pos) - np.testing.assert_array_almost_equal(actual_pos, expected_pos) +def test_get_weights(genotyped_pos, ungenotyped_pos, expected_weights, expected_idx): + actual_weights, actual_idx = tests.beagle.get_weights( + genotyped_pos, ungenotyped_pos + ) + np.testing.assert_array_almost_equal(actual_weights, expected_weights) np.testing.assert_array_equal(actual_idx, expected_idx) - actual_idx, actual_idx = tests.beagle_numba.get_weights( + actual_weights, actual_idx = tests.beagle_numba.get_weights( genotyped_pos, ungenotyped_pos ) - np.testing.assert_array_almost_equal(actual_pos, expected_pos) + np.testing.assert_array_almost_equal(actual_weights, expected_weights) np.testing.assert_array_equal(actual_idx, expected_idx)