Skip to content

Commit

Permalink
Allow ancient samples
Browse files Browse the repository at this point in the history
Rework build-prior and inside / outside logic to allow historical samples

And speed up time constraint algorithms while also allowing nodes to be out of time order
  • Loading branch information
awohns authored and hyanwong committed Jan 9, 2023
1 parent 38493ff commit 6e6cf7b
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 98 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
tskit>=0.4.0
tskit>=0.5.2
tsinfer>=0.3.0
flake8
numpy
Expand Down
17 changes: 3 additions & 14 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,7 +1048,7 @@ def test_dangling_fails(self):
print(ts.draw_text())
print("Samples:", ts.samples())
Ne = 0.5
with pytest.raises(ValueError, match="simplified"):
with pytest.raises(ValueError, match="simplify"):
tsdate.build_prior_grid(ts, Ne, timepoints=np.array([0, 1.2, 2]))
# mut_rate = 1
# eps = 1e-6
Expand Down Expand Up @@ -1421,7 +1421,7 @@ def test_date_input(self):

def test_sample_as_parent_fails(self):
ts = utility_functions.single_tree_ts_n3_sample_as_parent()
with pytest.raises(NotImplementedError):
with pytest.raises(ValueError, match="samples at non-zero times"):
tsdate.date(ts, mutation_rate=None, Ne=1)

def test_recombination_not_implemented(self):
Expand Down Expand Up @@ -1532,18 +1532,7 @@ def test_constrain_ages_topo(self):
ts = utility_functions.two_tree_ts()
post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0])
eps = 1e-6
nodes_to_date = np.array([3, 4, 5])
constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date)
assert np.array_equal(
np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages
)

def test_constrain_ages_topo_no_nodes_to_date(self):
ts = utility_functions.two_tree_ts()
post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0])
eps = 1e-6
nodes_to_date = None
constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date)
constrained_ages = constrain_ages_topo(ts, post_mn, eps)
assert np.array_equal(
np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages
)
Expand Down
19 changes: 16 additions & 3 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_bad_Ne(self):

def test_dangling_failure(self):
ts = utility_functions.single_tree_ts_n2_dangling()
with pytest.raises(ValueError, match="simplified"):
with pytest.raises(ValueError, match="simplify"):
tsdate.date(ts, mutation_rate=None, Ne=1)

def test_unary_failure(self):
Expand Down Expand Up @@ -271,16 +271,29 @@ def test_fails_multi_root(self):
with pytest.raises(ValueError):
tsdate.date(multiroot_ts, Ne=1, mutation_rate=2, priors=good_priors)

def test_non_contemporaneous(self):
def test_non_contemporaneous_warn(self):
samples = [
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=1.0),
]
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=12)
with pytest.raises(NotImplementedError):
with pytest.raises(ValueError, match="samples at non-zero times"):
tsdate.date(ts, Ne=1, mutation_rate=2)
with pytest.raises(ValueError, match="samples at non-zero times"):
tsdate.build_prior_grid(ts, Ne=1)

def test_non_contemporaneous(self):
samples = [
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=1.0),
]
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=12)
priors = tsdate.build_prior_grid(ts, Ne=1, allow_historical_samples=True)
tsdate.date(ts, priors=priors, mutation_rate=2)

def test_no_mutation_times(self):
ts = msprime.simulate(20, Ne=1, mutation_rate=1, random_seed=12)
Expand Down
16 changes: 12 additions & 4 deletions tsdate/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ def __init__(
] = (-np.arange(num_nodes - self.num_nonfixed) - 1)
self.probability_space = LIN

def fixed_node_ids(self):
return np.where(self.row_lookup < 0)[0]

def nonfixed_node_ids(self):
return np.where(self.row_lookup >= 0)[0]

def force_probability_space(self, probability_space):
"""
probability_space can be "logarithmic" or "linear": this function will force
Expand Down Expand Up @@ -140,6 +146,9 @@ def normalize(self):
else:
raise RuntimeError("Probability space is not", LIN, "or", LOG)

def is_fixed(self, node_id):
return self.row_lookup[node_id] < 0

def __getitem__(self, node_id):
index = self.row_lookup[node_id]
if index < 0:
Expand Down Expand Up @@ -207,8 +216,7 @@ def fill_fixed(orig, fixed_data):
new_obj.fixed_data = fill_fixed(
self, grid_data if fixed_data is None else fixed_data
)
if probability_space is None:
new_obj.probability_space = self.probability_space
else:
new_obj.probability_space = probability_space
new_obj.probability_space = self.probability_space
if probability_space is not None:
new_obj.force_probability_space(probability_space)
return new_obj
127 changes: 77 additions & 50 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _lik(muts, span, dt, mutation_rate, normalize=True):
"""
ll = scipy.stats.poisson.pmf(muts, dt * mutation_rate * span)
if normalize:
return ll / np.max(ll)
return ll / np.nanmax(ll)
else:
return ll

Expand Down Expand Up @@ -258,15 +258,28 @@ def get_mut_lik_fixed_node(self, edge):

mutations_on_edge = self.mut_edges[edge.id]
child_time = self.ts.node(edge.child).time
#assert child_time == 0
# Temporary hack - we should really take a more precise likelihood
return self._lik(
mutations_on_edge,
edge.span,
self.timediff,
self.mut_rate,
normalize=self.normalize,
)
if child_time == 0:
return self._lik(
mutations_on_edge,
edge.span,
self.timediff,
self.mut_rate,
normalize=self.normalize,
)
else:
timediff = self.timepoints - child_time + 1e-8
# Temporary hack - we should really take a more precise likelihood
likelihood = self._lik(
mutations_on_edge,
edge.span,
timediff,
self.mut_rate,
normalize=self.normalize,
)
# Prevent child from being older than parent
likelihood[timediff < 0] = 0

return likelihood

def get_mut_lik_lower_tri(self, edge):
"""
Expand Down Expand Up @@ -389,7 +402,7 @@ def get_fixed(self, arr, edge):
return arr * liks

def scale_geometric(self, fraction, value):
return value**fraction
return value ** fraction


class LogLikelihoods(Likelihoods):
Expand Down Expand Up @@ -429,7 +442,7 @@ def _lik(muts, span, dt, mutation_rate, normalize=True):
"""
ll = scipy.stats.poisson.logpmf(muts, dt * mutation_rate * span)
if normalize:
return ll - np.max(ll)
return ll - np.nanmax(ll)
else:
return ll

Expand Down Expand Up @@ -634,11 +647,22 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
inside = self.priors.clone_with_new_data( # store inside matrix values
grid_data=np.nan, fixed_data=self.lik.identity_constant
)
# It is possible that a simple node is non-fixed, in which case we want to
# provide an inside array that reflects the prior distribution
nonfixed_samples = np.intersect1d(inside.nonfixed_node_ids(), self.ts.samples())
for u in nonfixed_samples:
# this is in the same probability space as the prior, so we should be
# OK just to copy the prior values straight in. It's unclear to me (Yan)
# how/if they should be normalised, however
inside[u][:] = self.priors[u]

if cache_inside:
g_i = np.full(
(self.ts.num_edges, self.lik.grid_size), self.lik.identity_constant
)
norm = np.full(self.ts.num_nodes, np.nan)
to_visit = np.zeros(self.ts.num_nodes, dtype=bool)
to_visit[inside.nonfixed_node_ids()] = True
# Iterate through the nodes via groupby on parent node
for parent, edges in tqdm(
self.edges_by_parent_asc(),
Expand Down Expand Up @@ -673,14 +697,23 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
"dangling nodes: please simplify it"
)
daughter_val = self.lik.scale_geometric(
spanfrac, self.lik.make_lower_tri(inside[edge.child])
spanfrac, self.lik.make_lower_tri(inside_values)
)
edge_lik = self.lik.get_inside(daughter_val, edge)
val = self.lik.combine(val, edge_lik)
if np.all(val == 0):
raise ValueError
if cache_inside:
g_i[edge.id] = edge_lik
norm[parent] = np.max(val) if normalize else 1
norm[parent] = np.max(val) if normalize else self.lik.identity_constant
inside[parent] = self.lik.reduce(val, norm[parent])
to_visit[parent] = False

# There may be nodes that are not parents but are also not fixed (e.g.
# undated sample nodes). These need an identity normalization constant
for unfixed_unvisited in np.where(to_visit)[0]:
norm[unfixed_unvisited] = self.lik.identity_constant

if cache_inside:
self.g_i = self.lik.reduce(g_i, norm[self.ts.tables.edges.child, None])
# Keep the results in this object
Expand Down Expand Up @@ -732,10 +765,10 @@ def outside_pass(
if ignore_oldest_root:
if edge.parent == self.ts.num_nodes - 1:
continue
#if edge.parent in self.fixednodes:
# raise RuntimeError(
# "Fixed nodes cannot currently be parents in the TS"
# )
if edge.parent in self.fixednodes:
raise RuntimeError(
"Fixed nodes cannot currently be parents in the TS"
)
# Geometric scaling works exactly for all nodes fixed in graph
# but is an approximation when times are unknown.
spanfrac = edge.span / self.spans[child]
Expand Down Expand Up @@ -897,34 +930,32 @@ def posterior_mean_var(ts, posterior, *, fixed_node_set=None):
return ts, mn_post, vr_post


def constrain_ages_topo(ts, post_mn, eps, nodes_to_date=None, progress=False):
def constrain_ages_topo(ts, node_times, eps, progress=False):
"""
If predicted node times violate topology, restrict node ages so that they
must be older than all their children.
If node_times violate topology, return increased node_times so that each node is
guaranteed to be older than any of its their children.
"""
new_mn_post = np.copy(post_mn)
if nodes_to_date is None:
nodes_to_date = np.arange(ts.num_nodes, dtype=np.uint64)
nodes_to_date = nodes_to_date[~np.isin(nodes_to_date, ts.samples())]

tables = ts.tables
parents = tables.edges.parent
nd_children = tables.edges.child[np.argsort(parents)]
parents = sorted(parents)
parents_unique = np.unique(parents, return_index=True)
parent_indices = parents_unique[1][np.isin(parents_unique[0], nodes_to_date)]
for index, nd in tqdm(
enumerate(sorted(nodes_to_date)), desc="Constrain Ages", disable=not progress
edges_parent = ts.edges_parent
edges_child = ts.edges_child

new_node_times = np.copy(node_times)
# Traverse through the ARG, ensuring children come before parents.
# This can be done by iterating over groups of edges with the same parent
new_parent_edge_idx = np.where(np.diff(edges_parent) != 0)[0] + 1
for edges_start, edges_end in tqdm(
zip(
itertools.chain([0], new_parent_edge_idx),
itertools.chain(new_parent_edge_idx, [len(edges_parent)]),
),
desc="Constrain Ages",
disable=not progress,
):
if index + 1 != len(nodes_to_date):
children_index = np.arange(parent_indices[index], parent_indices[index + 1])
else:
children_index = np.arange(parent_indices[index], ts.num_edges)
children = nd_children[children_index]
time = np.max(new_mn_post[children])
if new_mn_post[nd] <= time:
new_mn_post[nd] = time + eps
return new_mn_post
parent = edges_parent[edges_start]
child_ids = edges_child[edges_start:edges_end] # May contain dups
oldest_child_time = np.max(new_node_times[child_ids])
if oldest_child_time >= new_node_times[parent]:
new_node_times[parent] = oldest_child_time + eps
return new_node_times


def date(
Expand Down Expand Up @@ -1015,7 +1046,7 @@ def date(
progress=progress,
**kwargs
)
constrained = constrain_ages_topo(tree_sequence, dates, eps, nds, progress)
constrained = constrain_ages_topo(tree_sequence, dates, eps, progress)
tables = tree_sequence.dump_tables()
tables.time_units = time_units
tables.nodes.time = constrained
Expand Down Expand Up @@ -1064,12 +1095,6 @@ def get_dates(
:return: tuple(mn_post, posterior, timepoints, eps, nodes_to_date)
"""
# Stuff yet to be implemented. These can be deleted once fixed
#for sample in tree_sequence.samples():
# if tree_sequence.node(sample).time != 0:
# raise NotImplementedError("Samples must all be at time 0")
fixed_nodes = set(tree_sequence.samples())

# Default to not creating approximate priors unless ts has > 1000 samples
approx_priors = False
if tree_sequence.num_samples > 1000:
Expand Down Expand Up @@ -1097,6 +1122,8 @@ def get_dates(
)
priors = priors

fixed_nodes = set(priors.fixed_node_ids())

if probability_space != base.LOG:
liklhd = Likelihoods(
tree_sequence,
Expand Down
Loading

0 comments on commit 6e6cf7b

Please sign in to comment.