Skip to content

Commit

Permalink
Implement NONCOPY for diploid case
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jun 1, 2024
1 parent b8a99aa commit 5c1de0b
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 60 deletions.
7 changes: 3 additions & 4 deletions lshmm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def convert_haplotypes_to_phased_genotypes(ref_panel):
The only allowable allele states are 0, 1, and NONCOPY (for partial ancestral haplotypes).
TODO: Handle multiallelic sites.
TODO: Handle NONCOPY.
Allowable genotype values are 0, 1, 2, and NONCOPY. If either one haplotype is NONCOPY
at a site, then the genotype at the site is assigned NONCOPY.
Expand All @@ -75,7 +74,7 @@ def convert_haplotypes_to_phased_genotypes(ref_panel):
:return: An array of reference genotypes.
:rtype: numpy.ndarray
"""
ALLOWED_ALLELE_STATES = np.array([0, 1], dtype=np.int32)
ALLOWED_ALLELE_STATES = np.array([0, 1, NONCOPY], dtype=np.int32)
assert np.all(
np.isin(np.unique(ref_panel), ALLOWED_ALLELE_STATES)
), f"Reference haplotypes contain illegal allele states."
Expand All @@ -85,8 +84,8 @@ def convert_haplotypes_to_phased_genotypes(ref_panel):
for i in range(num_sites):
site_alleles = ref_panel[i, :]
genotypes[i, :, :] = np.add.outer(site_alleles, site_alleles)
# genotypes[i, site_alleles == NONCOPY, :] = NONCOPY
# genotypes[i, :, site_alleles == NONCOPY] = NONCOPY
genotypes[i, site_alleles == NONCOPY, :] = NONCOPY
genotypes[i, :, site_alleles == NONCOPY] = NONCOPY
return genotypes


Expand Down
8 changes: 4 additions & 4 deletions tests/lsbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,6 @@ def get_examples_haploid(self, ts, include_ancestors):
return ref_panel, queries

def get_examples_diploid(self, ts, include_ancestors):
# TODO Handle NONCOPY properly.
if include_ancestors is True:
raise NotImplementedError
ref_panel = ts.genotype_matrix()
num_sites = ref_panel.shape[0]
# Take some haplotypes as queries from the reference panel.
Expand All @@ -92,7 +89,10 @@ def get_examples_diploid(self, ts, include_ancestors):
query_miss_most[:, 1:] = core.MISSING
queries = [query_1, query_2, query_miss_last, query_miss_mid, query_miss_most]
# Exclude the arbitrarily chosen queries from the reference panel.
ref_panel = ref_panel[:, 2:-2]
if include_ancestors:
ref_panel = self.get_ancestral_haplotypes(ts)
else:
ref_panel = ref_panel[:, 2:-2]
return ref_panel, queries

def get_examples_pars(
Expand Down
24 changes: 12 additions & 12 deletions tests/test_API.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
self.assertAllClose(ll_vs, ll)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(
Expand All @@ -186,7 +186,7 @@ def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(
Expand All @@ -196,7 +196,7 @@ def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors):
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(
Expand All @@ -206,7 +206,7 @@ def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors):
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(
Expand All @@ -216,7 +216,7 @@ def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(
Expand All @@ -226,7 +226,7 @@ def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_larger(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_custom_pars(
ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5
Expand Down Expand Up @@ -367,7 +367,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
self.assertAllClose(phased_path_vs, path)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(
Expand All @@ -377,7 +377,7 @@ def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(
Expand All @@ -387,7 +387,7 @@ def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors):
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(
Expand All @@ -397,7 +397,7 @@ def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors):
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(
Expand All @@ -407,7 +407,7 @@ def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(
Expand All @@ -417,7 +417,7 @@ def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_larger(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_custom_pars(
ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5
Expand Down
87 changes: 47 additions & 40 deletions tests/test_non_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,11 @@ def verify(
):
G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs)
s = core.convert_haplotypes_to_unphased_genotypes(query)

F_vs, c_vs, ll_vs = fbd.forwards_ls_dip(n, m, G_vs, s, e_vs, r, norm=True)
B_vs = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_vs, r)
self.assertAllClose(np.sum(F_vs * B_vs, (1, 2)), np.ones(m))

F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop(
n, m, G_vs, s, e_vs, r, norm=True
)
Expand Down Expand Up @@ -155,7 +157,7 @@ def verify(
self.assertAllClose(ll_vs, ll_tmp)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("normalise", [True, False])
def test_ts_simple_n10_no_recomb(
self, scale_mutation_rate, include_ancestors, normalise
Expand All @@ -173,7 +175,7 @@ def test_ts_simple_n10_no_recomb(
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("normalise", [True, False])
def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors, normalise):
ts = self.get_ts_simple_n6()
Expand All @@ -187,7 +189,7 @@ def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors, normalise):
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("normalise", [True, False])
def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors, normalise):
ts = self.get_ts_simple_n8()
Expand All @@ -201,7 +203,7 @@ def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors, normalise):
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("normalise", [True, False])
def test_ts_simple_n8_high_recomb(
self, scale_mutation_rate, include_ancestors, normalise
Expand All @@ -217,7 +219,7 @@ def test_ts_simple_n8_high_recomb(
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("normalise", [True, False])
def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors, normalise):
ts = self.get_ts_simple_n16()
Expand All @@ -231,7 +233,7 @@ def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors, normalise):
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("normalise", [True, False])
def test_ts_larger(self, scale_mutation_rate, include_ancestors, normalise):
ts = self.get_ts_custom_pars(
Expand Down Expand Up @@ -395,30 +397,13 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
):
G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs)
s = core.convert_haplotypes_to_unphased_genotypes(query)
V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_naive(n, m, G_vs, s, e_vs, r)
path_vs = vd.backwards_viterbi_dip(m, V_vs[m - 1, :, :], P_vs)

V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem(n, m, G_vs, s, e_vs, r)
path_vs = vd.backwards_viterbi_dip(m, V_vs, P_vs)
phased_path_vs = vd.get_phased_path(n, path_vs)
path_ll_vs = vd.path_ll_dip(n, m, G_vs, phased_path_vs, s, e_vs, r)
self.assertAllClose(ll_vs, path_ll_vs)

V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_low_mem(
n, m, G_vs, s, e_vs, r
)
path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp)
phased_path_tmp = vd.get_phased_path(n, path_tmp)
path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r)
self.assertAllClose(ll_tmp, path_ll_tmp)
self.assertAllClose(ll_vs, ll_tmp)

V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_low_mem(
n, m, G_vs, s, e_vs, r
)
path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp)
phased_path_tmp = vd.get_phased_path(n, path_tmp)
path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r)
self.assertAllClose(ll_tmp, path_ll_tmp)
self.assertAllClose(ll_vs, ll_tmp)

(
V_tmp,
V_argmaxes_tmp,
Expand All @@ -442,17 +427,39 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
self.assertAllClose(ll_tmp, path_ll_tmp)
self.assertAllClose(ll_vs, ll_tmp)

V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_vec(
n, m, G_vs, s, e_vs, r
)
path_tmp = vd.backwards_viterbi_dip(m, V_tmp[m - 1, :, :], P_tmp)
phased_path_tmp = vd.get_phased_path(n, path_tmp)
path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r)
self.assertAllClose(ll_tmp, path_ll_tmp)
self.assertAllClose(ll_vs, ll_tmp)
MAX_NUM_REF_HAPS = 100
num_ref_haps = H_vs.shape[1]
if num_ref_haps <= MAX_NUM_REF_HAPS:
# Run tests for the naive implementations.
V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive(
n, m, G_vs, s, e_vs, r
)
path_tmp = vd.backwards_viterbi_dip(m, V_tmp[m - 1, :, :], P_tmp)
phased_path_tmp = vd.get_phased_path(n, path_tmp)
path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r)
self.assertAllClose(ll_tmp, path_ll_tmp)
self.assertAllClose(ll_vs, ll_tmp)

V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_low_mem(
n, m, G_vs, s, e_vs, r
)
path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp)
phased_path_tmp = vd.get_phased_path(n, path_tmp)
path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r)
self.assertAllClose(ll_tmp, path_ll_tmp)
self.assertAllClose(ll_vs, ll_tmp)

V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_vec(
n, m, G_vs, s, e_vs, r
)
path_tmp = vd.backwards_viterbi_dip(m, V_tmp[m - 1, :, :], P_tmp)
phased_path_tmp = vd.get_phased_path(n, path_tmp)
path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r)
self.assertAllClose(ll_tmp, path_ll_tmp)
self.assertAllClose(ll_vs, ll_tmp)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(
Expand All @@ -462,7 +469,7 @@ def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(
Expand All @@ -472,7 +479,7 @@ def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors):
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(
Expand All @@ -482,7 +489,7 @@ def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors):
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(
Expand All @@ -492,7 +499,7 @@ def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(
Expand All @@ -502,7 +509,7 @@ def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_larger(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_custom_pars(
ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5
Expand Down

0 comments on commit 5c1de0b

Please sign in to comment.