Major reorganisation
szhan committed Apr 19, 2024
1 parent e82e151 commit f1b99b3
Showing 4 changed files with 336 additions and 246 deletions.
2 changes: 1 addition & 1 deletion lshmm/vit_diploid.py
@@ -315,7 +315,7 @@ def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r):


def forwards_viterbi_dip_naive_full_vec(n, m, G, s, e, r):
"""Fully vectorised naive LS diploid Viterbi algorithm using numpy."""
"""Fully vectorised naive implementation using Numpy."""
char_both = np.eye(n * n).ravel().reshape((n, n, n, n))
char_col = np.tile(np.sum(np.eye(n * n).reshape((n, n, n, n)), 3), (n, 1, 1, 1))
char_row = np.copy(char_col).T
319 changes: 319 additions & 0 deletions tests/lsbase.py
@@ -0,0 +1,319 @@
import itertools

import numpy as np

import msprime

import lshmm.core as core


class LSBase:
"""Base class of Li and Stephens tests."""

def get_num_alleles(self, ref_haps, query):
assert ref_haps.shape[0] == query.shape[1]
num_sites = ref_haps.shape[0]
num_alleles = np.zeros(num_sites, dtype=np.int8)
exclusion_set = np.array([core.MISSING])
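        # Count the distinct non-MISSING alleles observed at each site across
        # the reference panel and the query.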
for i in range(num_sites):
uniq_alleles = np.unique(np.append(ref_haps[i, :], query[:, i]))
num_alleles[i] = np.sum(~np.isin(uniq_alleles, exclusion_set))
        assert np.all(num_alleles > 0), "Number of alleles cannot be zero."
return num_alleles

# Haploid
def get_examples_haploid(self, ts):
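        # The first haplotype of the genotype matrix is the query and the rest
        # form the reference panel. Queries also include the last panel haplotype
        # and copies of the query with MISSING data at the last site, a middle
        # site, and every site.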
H = ts.genotype_matrix()
s = H[:, 0].reshape(1, H.shape[0])
H = H[:, 1:]

haplotypes = [s, H[:, -1].reshape(1, H.shape[0])]
s_tmp = s.copy()
s_tmp[0, -1] = core.MISSING
haplotypes.append(s_tmp)
s_tmp = s.copy()
s_tmp[0, ts.num_sites // 2] = core.MISSING
haplotypes.append(s_tmp)
s_tmp = s.copy()
s_tmp[0, :] = core.MISSING
haplotypes.append(s_tmp)

return H, haplotypes

def get_emission_prob_matrix_haploid(
self, mu, m, n_alleles, scale_mutation_based_on_n_alleles
):
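        # Build an (m, 2) emission probability matrix: column 0 is the probability
        # of emitting a mismatching allele, column 1 of emitting a matching allele.
        # For example, at a biallelic site with mu = 0.01 and scaling enabled,
        # e[j, 0] = 0.01 and e[j, 1] = 0.99; at an invariant site the scaled
        # branch reduces to e[j, 0] = 0 and e[j, 1] = 1.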
e = np.zeros((m, 2))
if isinstance(mu, float):
mu = mu * np.ones(m)

if scale_mutation_based_on_n_alleles:
e[:, 0] = mu - mu * np.equal(
n_alleles, np.ones(m)
) # Add boolean in case we're at an invariant site
e[:, 1] = 1 - (n_alleles - 1) * mu
else:
for j in range(m):
if n_alleles[j] == 1:
# In case we're at an invariant site
e[j, 0] = 0
e[j, 1] = 1
else:
e[j, 0] = mu[j] / (n_alleles[j] - 1)
e[j, 1] = 1 - mu[j]
return e

def get_examples_parameters_haploid(self, ts, scale_mutation=True, seed=42):
"""
        Returns an iterator over combinations of haplotypes, recombination probabilities,
and mutation probabilities.
"""
np.random.seed(seed)
H, haplotypes = self.get_examples_haploid(ts)
n = H.shape[1]
m = ts.get_num_sites()

        # Equal recombination and mutation probabilities at every site.
r = np.zeros(m) + 0.01
mu = np.zeros(m) + 0.01
r[0] = 0

for s in haplotypes:
            # The number of alleles must be calculated from the genotype matrix,
            # because back mutations can make the number of alleles at a site
            # exceed the number of alleles present in the reference panel.
n_alleles = self.get_num_alleles(H, s)
e = self.get_emission_prob_matrix_haploid(
mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation
)
yield n, m, H, s, e, r, mu

# Mixture of random and extremes
rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)]
mus = [np.zeros(m) + 0.2, np.zeros(m) + 1e-6, np.random.rand(m) * 0.2]

for s, r, mu in itertools.product(haplotypes, rs, mus):
r[0] = 0
n_alleles = self.get_num_alleles(H, s)
e = self.get_emission_prob_matrix_haploid(
mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation
)
yield n, m, H, s, e, r, mu

# Diploid
def get_examples_diploid(self, ts, seed=42):
np.random.seed(seed)
H = ts.genotype_matrix()
s = H[:, 0].reshape(1, H.shape[0]) + H[:, 1].reshape(1, H.shape[0])
H = H[:, 2:]

genotypes = [
s,
H[:, -1].reshape(1, H.shape[0]) + H[:, -2].reshape(1, H.shape[0]),
]

s_tmp = s.copy()
s_tmp[0, -1] = core.MISSING
genotypes.append(s_tmp)
s_tmp = s.copy()
s_tmp[0, ts.num_sites // 2] = core.MISSING
genotypes.append(s_tmp)
s_tmp = s.copy()
s_tmp[0, :] = core.MISSING
genotypes.append(s_tmp)

m = ts.get_num_sites()
n = H.shape[1]

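        # G[i] contains the unphased genotypes (allele sums) of all ordered pairs
        # of reference haplotypes at site i.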
G = np.zeros((m, n, n))
for i in range(m):
G[i, :, :] = np.add.outer(H[i, :], H[i, :])

return H, G, genotypes

def get_emission_prob_matrix_diploid(self, mu, m):
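        # Build an (m, 8) emission probability matrix indexed by the core.* emission
        # state constants. For example, with mu = 0.01: EQUAL_BOTH_HOM -> 0.9801,
        # UNEQUAL_BOTH_HOM -> 0.0001, BOTH_HET -> 0.9802, REF_HOM_OBS_HET -> 0.0198,
        # REF_HET_OBS_HOM -> 0.0099, and MISSING_INDEX -> 1.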
e = np.zeros((m, 8))
e[:, core.EQUAL_BOTH_HOM] = (1 - mu) ** 2
e[:, core.UNEQUAL_BOTH_HOM] = mu**2
e[:, core.BOTH_HET] = (1 - mu) ** 2 + mu**2
e[:, core.REF_HOM_OBS_HET] = 2 * mu * (1 - mu)
e[:, core.REF_HET_OBS_HOM] = mu * (1 - mu)
e[:, core.MISSING_INDEX] = 1
return e

def get_examples_parameters_diploid(self, ts, seed=42):
np.random.seed(seed)
        H, G, genotypes = self.get_examples_diploid(ts)
n = H.shape[1]
m = ts.get_num_sites()

        # Equal recombination and mutation probabilities at every site.
r = np.zeros(m) + 0.01
mu = np.zeros(m) + 0.01
r[0] = 0

        e = self.get_emission_prob_matrix_diploid(mu, m)

for s in genotypes:
yield n, m, G, s, e, r, mu

# Mixture of random and extremes
rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)]
mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33]

        e = self.get_emission_prob_matrix_diploid(mu, m)

for s, r, mu in itertools.product(genotypes, rs, mus):
r[0] = 0
            e = self.get_emission_prob_matrix_diploid(mu, m)
yield n, m, G, s, e, r, mu

def get_examples_parameters_larger_diploid(
self, ts, mean_r=1e-5, mean_mu=1e-5, seed=42
):
np.random.seed(seed)
        H, G, genotypes = self.get_examples_diploid(ts)

m = ts.get_num_sites()
n = H.shape[1]

r = mean_r * np.ones(m) * ((np.random.rand(m) + 0.5) / 2)
r[0] = 0

mu = mean_mu * np.ones(m) * ((np.random.rand(m) + 0.5) / 2)

        e = self.get_emission_prob_matrix_diploid(mu, m)

for s in genotypes:
yield n, m, G, s, e, r, mu

# Prepare simple example datasets.
def get_simple_n_10_no_recombination(self, seed=42):
ts = msprime.simulate(
10,
recombination_rate=0,
mutation_rate=0.5,
random_seed=seed,
)
assert ts.num_sites > 3
return ts

def get_simple_n_6(self, seed=42):
ts = msprime.simulate(
6,
recombination_rate=2,
mutation_rate=7,
random_seed=seed,
)
assert ts.num_sites > 5
return ts

def get_simple_n_8(self, seed=42):
ts = msprime.simulate(
8,
recombination_rate=2,
mutation_rate=5,
random_seed=seed,
)
assert ts.num_sites > 5
return ts

def get_simple_n_8_high_recombination(self, seed=42):
ts = msprime.simulate(
8,
recombination_rate=20,
mutation_rate=5,
random_seed=seed,
)
assert ts.num_trees > 15
assert ts.num_sites > 5
return ts

def get_simple_n_16(self, seed=42):
ts = msprime.simulate(
16,
recombination_rate=2,
mutation_rate=5,
random_seed=seed,
)
assert ts.num_sites > 5
return ts

# Prepare example datasets with multiallelic sites.
def get_multiallelic_n_10_no_recombination(self, seed=42):
ts = msprime.sim_ancestry(
samples=10,
recombination_rate=0,
sequence_length=10,
population_size=1e4,
random_seed=seed,
)
ts = msprime.sim_mutations(
ts,
rate=1e-5,
random_seed=seed,
)
assert ts.num_sites > 3
return ts

def get_multiallelic_n_6(self, seed=42):
ts = msprime.sim_ancestry(
samples=6,
recombination_rate=1e-4,
sequence_length=40,
population_size=1e4,
random_seed=seed,
)
ts = msprime.sim_mutations(
ts,
rate=1e-3,
random_seed=seed,
)
assert ts.num_sites > 5
return ts

def get_multiallelic_n_8(self, seed=42):
ts = msprime.sim_ancestry(
samples=8,
recombination_rate=1e-4,
sequence_length=20,
population_size=1e4,
random_seed=seed,
)
ts = msprime.sim_mutations(
ts,
rate=1e-4,
random_seed=seed,
)
assert ts.num_sites > 5
assert ts.num_trees > 15
return ts

def get_multiallelic_n_16(self, seed=42):
ts = msprime.sim_ancestry(
samples=16,
recombination_rate=1e-2,
sequence_length=20,
population_size=1e4,
random_seed=seed,
)
ts = msprime.sim_mutations(
ts,
rate=1e-4,
random_seed=seed,
)
assert ts.num_sites > 5
return ts

def verify(self, ts):
raise NotImplementedError()

def assertAllClose(self, A, B):
assert np.allclose(A, B, rtol=1e-9, atol=0.0)


class FBAlgorithmBase(LSBase):
"""Base for testing forwards-backwards algorithms."""


class ViterbiAlgorithmBase(LSBase):
"""Base for testing Viterbi algoritms."""