Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to only storing fixed node data in Prior #123

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 108 additions & 76 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import tsinfer

import tsdate
from tsdate.date import (SpansBySamples, PriorParams,
from tsdate.date import (SpansBySamples, PriorParams, LIN, LOG,
ConditionalCoalescentTimes, fill_prior, Likelihoods,
LogLikelihoods, LogLikelihoodsStreaming, InOutAlgorithms,
NodeGridValues, gamma_approx, constrain_ages_topo) # NOQA
Expand Down Expand Up @@ -189,12 +189,19 @@ def test_larger_find_node_tip_weights(self):
self.verify_weights(ts)

def test_dangling_nodes_warn(self):
ts = utility_functions.single_tree_ts_n3_dangling()
ts = utility_functions.single_tree_ts_n2_dangling()
with self.assertLogs(level="WARNING") as log:
self.verify_weights(ts)
self.assertGreater(len(log.output), 0)
self.assertIn("dangling", log.output[0])

def test_simple_non_contemporaneous(self):
ts = utility_functions.two_tree_ts_n3_non_contemporaneous()
n = len([s for s in ts.samples() if ts.node(s).time == 0])
span_data = self.verify_weights(ts)
self.assertEqual(span_data.lookup_weight(4, n, 2), 0.2) # 2 contemporanous tips
self.assertEqual(span_data.lookup_weight(4, n, 1), 0.8) # only 1 contemporanous

@unittest.skip("YAN to fix")
def test_truncated_nodes(self):
Ne = 1e2
Expand Down Expand Up @@ -337,9 +344,10 @@ class TestMixturePrior(unittest.TestCase):
def get_mixture_prior_params(self, ts, prior_distr):
span_data = SpansBySamples(ts)
priors = ConditionalCoalescentTimes(None, prior_distr=prior_distr)
priors.add(ts.num_samples, approximate=False)
for total_fixed in span_data.total_fixed_at_0_counts:
priors.add(total_fixed, approximate=False)
mixture_prior = priors.get_mixture_prior_params(span_data)
return(mixture_prior)
return mixture_prior

def test_one_tree_n2(self):
ts = utility_functions.single_tree_ts_n2()
Expand Down Expand Up @@ -420,12 +428,30 @@ def test_two_tree_mutation_ts(self):
self.assertTrue(
np.allclose(mixture_prior[5, self.alpha_beta], [1.6, 1.2]))

def test_simple_non_contemporaneous(self):
ts = utility_functions.two_tree_ts_n3_non_contemporaneous()
mixture_prior = self.get_mixture_prior_params(ts, 'gamma')
self.assertTrue(
np.allclose(mixture_prior[4, self.alpha_beta], [0.11111, 0.55555]))

def test_simulated_non_contemporaneous(self):
samples = [
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=1.0)
]
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=123)
self.get_mixture_prior_params(ts, 'lognorm')
self.get_mixture_prior_params(ts, 'gamma')


class TestPriorVals(unittest.TestCase):
def verify_prior_vals(self, ts, prior_distr):
span_data = SpansBySamples(ts)
priors = ConditionalCoalescentTimes(None, prior_distr=prior_distr)
priors.add(ts.num_samples, approximate=False)
for total_fixed in span_data.total_fixed_at_0_counts:
priors.add(total_fixed, approximate=False)
grid = np.linspace(0, 3, 3)
mixture_prior = priors.get_mixture_prior_params(span_data)
prior_vals = fill_prior(mixture_prior, grid, ts, prior_distr=prior_distr)
Expand Down Expand Up @@ -470,6 +496,23 @@ def test_tree_with_unary_nodes(self):
self.assertTrue(np.allclose(prior_vals[4], [0, 1, 0.093389]))
self.assertTrue(np.allclose(prior_vals[3], [0, 1, 0.011109]))

def test_simple_non_contemporaneous(self):
ts = utility_functions.two_tree_ts_n3_non_contemporaneous()
prior_vals = self.verify_prior_vals(ts, 'gamma')
self.assertEqual(prior_vals.fixed_time(2), ts.node(2).time)

def test_simulated_non_contemporaneous(self):
samples = [
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=1.0)
]
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=123)
prior_vals = self.verify_prior_vals(ts, 'gamma')
print(prior_vals.timepoints)
raise


class TestLikelihoodClass(unittest.TestCase):
def poisson(self, l, x, normalize=True):
Expand Down Expand Up @@ -671,102 +714,91 @@ def test_logsumexp_streaming(self):


class TestNodeGridValuesClass(unittest.TestCase):
# TODO - needs a few more tests in here
def test_init(self):
num_nodes = 5
ids = np.array([3, 4])
nonfixed_ids = np.array([3, 2])
timepoints = np.array(range(10))
store = NodeGridValues(num_nodes, ids, timepoints, fill_value=6)
self.assertEquals(store.grid_data.shape, (len(ids), len(timepoints)))
self.assertEquals(len(store.fixed_data), (num_nodes-len(ids)))
store = NodeGridValues(timepoints, gridnodes=nonfixed_ids, fill_value=6)
self.assertEquals(store.grid_data.shape, (len(nonfixed_ids), len(timepoints)))
self.assertTrue(np.all(store.grid_data == 6))
self.assertTrue(np.all(store.fixed_data == 6))
for i in range(np.max(nonfixed_ids)+1):
if i in nonfixed_ids:
self.assertTrue(np.all(store[i] == 6))
else:
with self.assertRaises(IndexError):
_ = store[i]

ids = np.array([3, 4], dtype=np.int32)
store = NodeGridValues(num_nodes, ids, timepoints, fill_value=5)
self.assertEquals(store.grid_data.shape, (len(ids), len(timepoints)))
self.assertEquals(len(store.fixed_data), num_nodes-len(ids))
self.assertTrue(np.all(store.fixed_data == 5))
def test_probability_spaces(self):
nonfixed_ids = np.array([3, 4])
timepoints = np.array(range(10))
store = NodeGridValues(timepoints, gridnodes=nonfixed_ids, fill_value=0.5)
self.assertTrue(np.all(store.grid_data == 0.5))
store.force_probability_space(LIN)
self.assertTrue(np.all(store.grid_data == 0.5))
store.force_probability_space(LOG)
self.assertTrue(np.allclose(store.grid_data, np.log(0.5)))
store.force_probability_space(LOG)
self.assertTrue(np.allclose(store.grid_data, np.log(0.5)))
store.force_probability_space(LIN)
self.assertTrue(np.all(store.grid_data == 0.5))
self.assertRaises(ValueError, store.force_probability_space, "foobar")

def test_set_and_get(self):
num_nodes = 5
grid_size = 2
timepoints = [0, 1.1]
fill = {}
for ids in ([3, 4], []):
for nonfixed_ids in ([3, 4], [0]):
np.random.seed(1)
store = NodeGridValues(
num_nodes, np.array(ids, dtype=np.int32), np.array(range(grid_size)))
for i in range(num_nodes):
fill[i] = np.random.random(grid_size if i in ids else None)
store[i] = fill[i]
for i in range(num_nodes):
store = NodeGridValues(timepoints, gridnodes=nonfixed_ids)
for i in range(5):
fill[i] = np.random.random(len(store.timepoints))
if i in nonfixed_ids:
store[i] = fill[i]
else:
with self.assertRaises(IndexError):
store[i] = fill[i]
for i in nonfixed_ids:
self.assertTrue(np.all(fill[i] == store[i]))
self.assertRaises(IndexError, store.__getitem__, num_nodes)

def test_bad_init(self):
ids = [3, 4]
self.assertRaises(ValueError, NodeGridValues, 3, np.array(ids),
np.array([0, 1.2, 2]))
self.assertRaises(AttributeError, NodeGridValues, 5, np.array(ids), -1)
self.assertRaises(ValueError, NodeGridValues, 5, np.array([-1]),
np.array([0, 1.2, 2]))
timepoints = [0, 1.2, 2]
nonfixed_ids = [4, 0]
NodeGridValues(timepoints, gridnodes=nonfixed_ids)
# duplicate ids
self.assertRaises(ValueError, NodeGridValues, timepoints, gridnodes=[4, 4, 0])
# bad ids
self.assertRaises(
ValueError, NodeGridValues, timepoints, gridnodes=np.array([[1, 4], [2, 0]]))
self.assertRaises(OverflowError, NodeGridValues, timepoints, gridnodes=[-1, 4])
# bad timepoint
self.assertRaises(ValueError, NodeGridValues, [], gridnodes=nonfixed_ids)

def test_clone(self):
num_nodes = 10
grid_size = 2
ids = [3, 4]
orig = NodeGridValues(num_nodes, np.array(ids), np.array(range(grid_size)))
timepoints = [0, 1]
nonfixed_ids = [3, 4]
orig = NodeGridValues(timepoints, gridnodes=nonfixed_ids)
orig[3] = np.array([1, 2])
orig[4] = np.array([4, 3])
orig[0] = 1.5
orig[9] = 2.5
# test with np.zeros
clone = NodeGridValues.clone_with_new_data(orig, 0)
clone = orig.clone_grid_with_new_data(0)
self.assertEquals(clone.grid_data.shape, orig.grid_data.shape)
self.assertEquals(clone.fixed_data.shape, orig.fixed_data.shape)
self.assertTrue(np.all(clone.grid_data == 0))
self.assertTrue(np.all(clone.fixed_data == 0))
# test with something else
clone = NodeGridValues.clone_with_new_data(orig, 5)
clone = orig.clone_grid_with_new_data(5)
self.assertEquals(clone.grid_data.shape, orig.grid_data.shape)
self.assertEquals(clone.fixed_data.shape, orig.fixed_data.shape)
self.assertTrue(np.all(clone.grid_data == 5))
self.assertTrue(np.all(clone.fixed_data == 5))
# test with different
scalars = np.arange(num_nodes - len(ids))
clone = NodeGridValues.clone_with_new_data(orig, 0, scalars)
self.assertEquals(clone.grid_data.shape, orig.grid_data.shape)
self.assertEquals(clone.fixed_data.shape, orig.fixed_data.shape)
self.assertTrue(np.all(clone.grid_data == 0))
self.assertTrue(np.all(clone.fixed_data == scalars))

clone = NodeGridValues.clone_with_new_data(
orig, np.array([[1, 2], [4, 3]]))
for i in range(num_nodes):
if i in ids:
self.assertTrue(np.all(clone[i] == orig[i]))
else:
self.assertTrue(np.isnan(clone[i]))
clone = NodeGridValues.clone_with_new_data(
orig, np.array([[1, 2], [4, 3]]), 0)
for i in range(num_nodes):
if i in ids:
clone = orig.clone_grid_with_new_data(np.array([[1, 2], [4, 3]]))
for i in range(np.max(nonfixed_ids)+1):
if i in nonfixed_ids:
self.assertTrue(np.all(clone[i] == orig[i]))
else:
self.assertEquals(clone[i], 0)
self.assertRaises(IndexError, clone.__getitem__, i)

def test_bad_clone(self):
num_nodes = 10
ids = [3, 4]
orig = NodeGridValues(num_nodes, np.array(ids), np.array([0, 1.2]))
self.assertRaises(
ValueError,
NodeGridValues.clone_with_new_data,
orig, np.array([[1, 2, 3], [4, 5, 6]]))
ids = np.array([3, 4])
timepoints = np.array([0, 1.2])
orig = NodeGridValues(timepoints, gridnodes=ids)
self.assertRaises(
ValueError,
NodeGridValues.clone_with_new_data,
orig, 0, np.array([[1, 2], [4, 5]]))
ValueError, orig.clone_grid_with_new_data, np.array([[1, 2, 3], [4, 5, 6]]))


class TestAlgorithmClass(unittest.TestCase):
Expand All @@ -780,7 +812,7 @@ def test_nonmatching_prior_vs_lik_timepoints(self):

def test_nonmatching_prior_vs_lik_fixednodes(self):
ts1 = utility_functions.single_tree_ts_n3()
ts2 = utility_functions.single_tree_ts_n3_dangling()
ts2 = utility_functions.single_tree_ts_n2_dangling()
timepoints = np.array([0, 1.2, 2])
prior = tsdate.build_prior_grid(ts1, timepoints)
lls = Likelihoods(ts2, prior.timepoints)
Expand Down Expand Up @@ -892,7 +924,7 @@ def test_two_tree_mutation_ts(self):
self.assertTrue(np.allclose(algo.inside[5], np.array([0, 7.06320034e-11, 1])))

def test_dangling_fails(self):
ts = utility_functions.single_tree_ts_n3_dangling()
ts = utility_functions.single_tree_ts_n2_dangling()
print(ts.draw_text())
print("Samples:", ts.samples())
prior = tsdate.build_prior_grid(ts, timepoints=np.array([0, 1.2, 2]))
Expand Down
13 changes: 10 additions & 3 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class TestPrebuilt(unittest.TestCase):
Tests for tsdate on prebuilt tree sequences
"""
def test_dangling_failure(self):
ts = utility_functions.single_tree_ts_n3_dangling()
ts = utility_functions.single_tree_ts_n2_dangling()
self.assertRaisesRegexp(ValueError, "dangling", tsdate.date, ts, Ne=1)

def test_unary_warning(self):
Expand All @@ -48,7 +48,7 @@ def test_unary_warning(self):
self.assertEqual(len(log.output), 1)
self.assertIn("unary nodes", log.output[0])

def test_fails_with_recombination(self):
def test_fails_with_recombination_clock(self):
ts = utility_functions.two_tree_mutation_ts()
for probability_space in (LOG, LIN):
self.assertRaises(
Expand All @@ -58,6 +58,12 @@ def test_fails_with_recombination(self):
NotImplementedError, tsdate.date, ts, Ne=1, recombination_rate=1,
probability_space=probability_space, mutation_rate=1)

def test_non_contemporaneous(self):
ts = utility_functions.two_tree_ts_n3_non_contemporaneous()
theta = 2
ts = msprime.mutate(ts, rate=theta)
tsdate.date(ts, Ne=1, mutation_rate=theta, probability_space=LIN)

# def test_simple_ts_n2(self):
# ts = utility_functions.single_tree_ts_n2()
# dated_ts = tsdate.date(ts, Ne=10000)
Expand Down Expand Up @@ -209,7 +215,8 @@ def test_non_contemporaneous(self):
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=1.0)
]
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2)
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=123)
print(ts.draw_text())
self.assertRaises(NotImplementedError, tsdate.date, ts, 1, 2)

@unittest.skip("YAN to fix")
Expand Down
Loading