Skip to content

Commit

Permalink
Major refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Apr 19, 2024
1 parent 8cb1910 commit 24e1240
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 87 deletions.
147 changes: 64 additions & 83 deletions tests/lsbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,27 @@ def get_examples_haploid(self, ts):
H = ts.genotype_matrix()
s = H[:, 0].reshape(1, H.shape[0])
H = H[:, 1:]

haplotypes = [s, H[:, -1].reshape(1, H.shape[0])]
s_tmp = s.copy()
s_tmp[0, -1] = core.MISSING
haplotypes.append(s_tmp)
s_tmp = s.copy()
s_tmp[0, ts.num_sites // 2] = core.MISSING
haplotypes.append(s_tmp)
s_tmp = s.copy()
s_tmp[0, :] = core.MISSING
haplotypes.append(s_tmp)

haplotypes = [
s,
H[:, -1].reshape(1, H.shape[0])
]
s_miss_last = s.copy()
s_miss_last[0, -1] = core.MISSING
s_miss_mid = s.copy()
s_miss_mid[0, ts.num_sites // 2] = core.MISSING
s_miss_all = s.copy()
s_miss_all[0, :] = core.MISSING
haplotypes.append(s_miss_last)
haplotypes.append(s_miss_mid)
haplotypes.append(s_miss_all)
return H, haplotypes

def get_emission_prob_matrix_haploid(
def get_emission_matrix_haploid(
self, mu, m, n_alleles, scale_mutation_based_on_n_alleles
):
e = np.zeros((m, 2))
if isinstance(mu, float):
mu = mu * np.ones(m)

if scale_mutation_based_on_n_alleles:
e[:, 0] = mu - mu * np.equal(
n_alleles, np.ones(m)
Expand All @@ -70,38 +70,32 @@ def get_emission_prob_matrix_haploid(
e[j, 1] = 1 - mu[j]
return e

def get_examples_parameters_haploid(self, ts, scale_mutation=True, seed=42):
"""
Returns an iterator over combinations of haplotypes, recombination probabilties,
and mutation probabilities.
"""
def get_examples_pars_haploid(self, ts, scale_mutation=True, seed=42):
"""Returns an iterator over combinations of examples and parameters."""
np.random.seed(seed)
H, haplotypes = self.get_examples_haploid(ts)
n = H.shape[1]
m = ts.num_sites

# Here we have equal mutation and recombination
r = np.zeros(m) + 0.01
mu = np.zeros(m) + 0.01
r[0] = 0

for s in haplotypes:
# Must be calculated from the genotype matrix because we can now get back mutations that
# result in the number of alleles being higher than the number of alleles in the reference panel.
n_alleles = self.get_num_alleles(H, s)
e = self.get_emission_prob_matrix_haploid(
mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation
)
yield n, m, H, s, e, r, mu

# Mixture of random and extremes
rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)]
mus = [np.zeros(m) + 0.2, np.zeros(m) + 1e-6, np.random.rand(m) * 0.2]

n = H.shape[1]
rs = [
np.zeros(m) + 0.01, # Equal recombination and mutation
np.zeros(m) + 0.999, # Extreme
np.zeros(m) + 1e-6, # Extreme
np.random.rand(m), # Random
]
mus = [
np.zeros(m) + 0.01, # Equal recombination and mutation
np.zeros(m) + 0.2, # Extreme
np.zeros(m) + 1e-6, # Extreme
np.random.rand(m) * 0.2, # Random
]
for s, r, mu in itertools.product(haplotypes, rs, mus):
r[0] = 0
# Must be calculated from the genotype matrix,
# because we can now get back mutations that
# result in the number of alleles being higher
# than the number of alleles in the reference panel.
n_alleles = self.get_num_alleles(H, s)
e = self.get_emission_prob_matrix_haploid(
e = self.get_emission_matrix_haploid(
mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation
)
yield n, m, H, s, e, r, mu
Expand All @@ -112,32 +106,27 @@ def get_examples_diploid(self, ts, seed=42):
H = ts.genotype_matrix()
s = H[:, 0].reshape(1, H.shape[0]) + H[:, 1].reshape(1, H.shape[0])
H = H[:, 2:]

genotypes = [
s,
H[:, -1].reshape(1, H.shape[0]) + H[:, -2].reshape(1, H.shape[0]),
H[:, -1].reshape(1, H.shape[0]) + H[:, -2].reshape(1, H.shape[0])
]

s_tmp = s.copy()
s_tmp[0, -1] = core.MISSING
genotypes.append(s_tmp)
s_tmp = s.copy()
s_tmp[0, ts.num_sites // 2] = core.MISSING
genotypes.append(s_tmp)
s_tmp = s.copy()
s_tmp[0, :] = core.MISSING
genotypes.append(s_tmp)

m = ts.get_num_sites()
s_miss_last = s.copy()
s_miss_last[0, -1] = core.MISSING
s_miss_mid = s.copy()
s_miss_mid[0, ts.num_sites // 2] = core.MISSING
s_miss_all = s.copy()
s_miss_all[0, :] = core.MISSING
genotypes.append(s_miss_last)
genotypes.append(s_miss_mid)
genotypes.append(s_miss_all)
m = ts.num_sites
n = H.shape[1]

G = np.zeros((m, n, n))
for i in range(m):
G[i, :, :] = np.add.outer(H[i, :], H[i, :])

return H, G, genotypes

def get_emission_prob_matrix_diploid(self, mu, m):
def get_emission_matrix_diploid(self, mu, m):
e = np.zeros((m, 8))
e[:, core.EQUAL_BOTH_HOM] = (1 - mu) ** 2
e[:, core.UNEQUAL_BOTH_HOM] = mu**2
Expand All @@ -147,47 +136,39 @@ def get_emission_prob_matrix_diploid(self, mu, m):
e[:, core.MISSING_INDEX] = 1
return e

def get_examples_parameters_diploid(self, ts, seed=42):
def get_examples_pars_diploid(self, ts, seed=42):
np.random.seed(seed)
H, G, genotypes = self.get_examples_diploid(ts)
n = H.shape[1]
m = ts.num_sites

# Here we have equal mutation and recombination
r = np.zeros(m) + 0.01
mu = np.zeros(m) + 0.01
r[0] = 0

e = self.get_emission_prob_matrix_diploid(mu, m)

for s in genotypes:
yield n, m, G, s, e, r, mu

# Mixture of random and extremes
rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)]
mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33]

n = H.shape[1]
rs = [
np.zeros(m) + 0.01, # Equal recombination and mutation
np.zeros(m) + 0.999, # Extreme
np.zeros(m) + 1e-6, # Extreme
np.random.rand(m), # Random
]
mus = [
np.zeros(m) + 0.01, # Equal recombination and mutation
np.zeros(m) + 0.33, # Extreme
np.zeros(m) + 1e-6, # Extreme
np.random.rand(m) * 0.33, # Random
]
for s, r, mu in itertools.product(genotypes, rs, mus):
r[0] = 0
e = self.get_emission_prob_matrix_diploid(mu, m)
e = self.get_emission_matrix_diploid(mu, m)
yield n, m, G, s, e, r, mu

def get_examples_parameters_larger_diploid(
def get_examples_pars_larger_diploid(
self, ts, mean_r=1e-5, mean_mu=1e-5, seed=42
):
np.random.seed(seed)
H, G, genotypes = self.get_examples_diploid(ts)

m = ts.get_num_sites()
m = H.shape[0]
n = H.shape[1]

r = mean_r * np.ones(m) * ((np.random.rand(m) + 0.5) / 2)
r[0] = 0

mu = mean_mu * np.ones(m) * ((np.random.rand(m) + 0.5) / 2)

e = self.get_emission_prob_matrix_diploid(mu, m)

e = self.get_emission_matrix_diploid(mu, m)
for s in genotypes:
yield n, m, G, s, e, r, mu

Expand Down
8 changes: 4 additions & 4 deletions tests/test_API.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_simple_n16(self):
self.verify(ts)

def verify(self, ts):
for n, m, H_vs, s, e_vs, r, mu in self.get_examples_parameters_haploid(ts):
for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars_haploid(ts):
F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r)
B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r)
F, c, ll = ls.forwards(H_vs, s, r, p_mutation=mu)
Expand Down Expand Up @@ -55,7 +55,7 @@ def test_simple_n16(self):
self.verify(ts)

def verify(self, ts):
for n, m, G_vs, s, e_vs, r, mu in self.get_examples_parameters_diploid(ts):
for n, m, G_vs, s, e_vs, r, mu in self.get_examples_pars_diploid(ts):
F_vs, c_vs, ll_vs = fbd.forward_ls_dip_loop(
n, m, G_vs, s, e_vs, r, norm=True
)
Expand Down Expand Up @@ -85,7 +85,7 @@ def test_simple_n16(self):
self.verify(ts)

def verify(self, ts):
for n, m, H_vs, s, e_vs, r, mu in self.get_examples_parameters_haploid(ts):
for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars_haploid(ts):
V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling(
n, m, H_vs, s, e_vs, r
)
Expand Down Expand Up @@ -114,7 +114,7 @@ def test_simple_n_16(self):
self.verify(ts)

def verify(self, ts):
for n, m, G_vs, s, e_vs, r, mu in self.get_examples_parameters_diploid(ts):
for n, m, G_vs, s, e_vs, r, mu in self.get_examples_pars_diploid(ts):
V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem(n, m, G_vs, s, e_vs, r)
path_vs = vd.backwards_viterbi_dip(m, V_vs, P_vs)
phased_path_vs = vd.get_phased_path(n, path_vs)
Expand Down

0 comments on commit 24e1240

Please sign in to comment.