Skip to content

Commit

Permalink
Major refactor
Browse files (browse the repository at this point in the history)
  • Loading branch information
szhan committed Apr 20, 2024
1 parent ec8ed22 commit 5789876
Show file tree
Hide file tree
Showing 2 changed files with 392 additions and 438 deletions.
113 changes: 63 additions & 50 deletions tests/lsbase.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,9 +13,6 @@ class LSBase:
def verify(self, ts):
raise NotImplementedError()

def verify_larger(self, ts):
pass

def assertAllClose(self, A, B):
np.testing.assert_allclose(A, B, rtol=1e-9, atol=0.0)

Expand Down Expand Up @@ -70,24 +67,39 @@ def get_emission_matrix_haploid(
e[j, 1] = 1 - mu[j]
return e

def get_examples_pars_haploid(self, ts, scale_mutation=True, seed=42):
def get_examples_pars_haploid(
self,
ts,
mean_r=None,
mean_mu=None,
scale_mutation=True,
seed=42
):
"""Returns an iterator over combinations of examples and parameters."""
np.random.seed(seed)
H, haplotypes = self.get_examples_haploid(ts)
m = ts.num_sites
n = H.shape[1]
rs = [
np.zeros(m) + 0.01, # Equal recombination and mutation
np.zeros(m) + 0.999, # Extreme
np.zeros(m) + 1e-6, # Extreme
np.random.rand(m), # Random
]
mus = [
np.zeros(m) + 0.01, # Equal recombination and mutation
np.zeros(m) + 0.2, # Extreme
np.zeros(m) + 1e-6, # Extreme
np.random.rand(m) * 0.2, # Random
]
if mean_r is not None and mean_mu is not None:
rs = [
mean_r * (np.random.rand(m) + 0.5) / 2
]
mus = [
mean_mu * (np.random.rand(m) + 0.5) / 2
]
else:
rs = [
np.zeros(m) + 0.01, # Equal recombination and mutation
np.zeros(m) + 0.999, # Extreme
np.zeros(m) + 1e-6, # Extreme
np.random.rand(m), # Random
]
mus = [
np.zeros(m) + 0.01, # Equal recombination and mutation
np.zeros(m) + 0.2, # Extreme
np.zeros(m) + 1e-6, # Extreme
np.random.rand(m) * 0.2, # Random
]
for s, r, mu in itertools.product(haplotypes, rs, mus):
r[0] = 0
# Must be calculated from the genotype matrix,
Expand All @@ -101,8 +113,7 @@ def get_examples_pars_haploid(self, ts, scale_mutation=True, seed=42):
yield n, m, H, s, e, r, mu

# Diploid
def get_examples_diploid(self, ts, seed=42):
np.random.seed(seed)
def get_examples_diploid(self, ts):
H = ts.genotype_matrix()
s = H[:, 0].reshape(1, H.shape[0]) + H[:, 1].reshape(1, H.shape[0])
H = H[:, 2:]
Expand All @@ -116,9 +127,10 @@ def get_examples_diploid(self, ts, seed=42):
s_miss_mid[0, ts.num_sites // 2] = core.MISSING
s_miss_all = s.copy()
s_miss_all[0, :] = core.MISSING
genotypes.append(s_miss_last)
genotypes.append(s_miss_mid)
genotypes.append(s_miss_all)
# FIXME Handle MISSING properly.
#genotypes.append(s_miss_last)
#genotypes.append(s_miss_mid)
#genotypes.append(s_miss_all)
m = ts.num_sites
n = H.shape[1]
G = np.zeros((m, n, n))
Expand All @@ -136,42 +148,43 @@ def get_emission_matrix_diploid(self, mu, m):
e[:, core.MISSING_INDEX] = 1
return e

def get_examples_pars_diploid(self, ts, seed=42):
def get_examples_pars_diploid(
self,
ts,
mean_r=None,
mean_mu=None,
seed=42
):
"""Returns an iterator over combinations of examples and parameters."""
np.random.seed(seed)
H, G, genotypes = self.get_examples_diploid(ts)
m = ts.num_sites
n = H.shape[1]
rs = [
np.zeros(m) + 0.01, # Equal recombination and mutation
np.zeros(m) + 0.999, # Extreme
np.zeros(m) + 1e-6, # Extreme
np.random.rand(m), # Random
]
mus = [
np.zeros(m) + 0.01, # Equal recombination and mutation
np.zeros(m) + 0.33, # Extreme
np.zeros(m) + 1e-6, # Extreme
np.random.rand(m) * 0.33, # Random
]
if mean_r is not None and mean_mu is not None:
rs = [
mean_r * (np.random.rand(m) + 0.5) / 2
]
mus = [
mean_mu * (np.random.rand(m) + 0.5) / 2
]
else:
rs = [
np.zeros(m) + 0.01, # Equal recombination and mutation
np.zeros(m) + 0.999, # Extreme
np.zeros(m) + 1e-6, # Extreme
np.random.rand(m), # Random
]
mus = [
np.zeros(m) + 0.01, # Equal recombination and mutation
np.zeros(m) + 0.33, # Extreme
np.zeros(m) + 1e-6, # Extreme
np.random.rand(m) * 0.33, # Random
]
for s, r, mu in itertools.product(genotypes, rs, mus):
r[0] = 0
e = self.get_emission_matrix_diploid(mu, m)
yield n, m, G, s, e, r, mu

def get_examples_pars_larger_diploid(self, ts, mean_r=1e-5, mean_mu=1e-5, seed=42):
"""Returns an iterator over combinations of examples and parameters."""
np.random.seed(seed)
H, G, genotypes = self.get_examples_diploid(ts)
m = H.shape[0]
n = H.shape[1]
r = mean_r * np.ones(m) * ((np.random.rand(m) + 0.5) / 2)
r[0] = 0
mu = mean_mu * np.ones(m) * ((np.random.rand(m) + 0.5) / 2)
e = self.get_emission_matrix_diploid(mu, m)
for s in genotypes:
yield n, m, G, s, e, r, mu

# Prepare simple example datasets.
def get_simple_n10_no_recombination(self, seed=42):
ts = msprime.simulate(
Expand Down Expand Up @@ -291,10 +304,10 @@ def get_multiallelic_n16(self, seed=42):
return ts

# Prepare a larger example dataset.
def get_large(self, n=50, length=1e5, mean_r=1e-5, mean_mu=1e-5, seed=42):
def get_larger(self, num_samples, seq_length, mean_r, mean_mu, seed=42):
ts = msprime.simulate(
n + 1,
length=length,
num_samples + 1,
length=seq_length,
mutation_rate=mean_mu,
recombination_rate=mean_r,
random_seed=seed,
Expand Down
Loading

0 comments on commit 5789876

Please sign in to comment.