Skip to content

Commit

Permalink
Add support for 'rec_tr_at_dict' in AnnotatedTree
Browse files Browse the repository at this point in the history
Now method 'extract_reconstructed_tree' also populates a new AnnotatedTree class member that holds an attribute transition dictionary, 'rec_tr_at_dict', for the reconstructed tree. This code must find the pruned internal nodes, grab their attribute transitions, and merge to the non-pruned trees of the reconstructed tree.

I also added a unit test for this new code, inside 'test_tree_extract_reconstructed.py'. Plotting functions also had to be adjusted, and they can now plot the reconstructed tree instead of the complete tree if the user asks for it.
  • Loading branch information
binho authored and binho committed Apr 19, 2024
1 parent d0977a4 commit 73f45fc
Show file tree
Hide file tree
Showing 13 changed files with 973 additions and 128 deletions.
2 changes: 1 addition & 1 deletion src/phylojunction/data/attribute_transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class AttributeTransition():
from_state: int
to_state: int # if happening at speciation, child 1
to_state2: int # if happening at speciation, child 2
at_speciation: bool
at_speciation: ty.Optional[bool]

# (i) if object stored in an AnnotatedTree's at_dict, this is
# the node subtending a branch (note that this transition may be
Expand Down
314 changes: 284 additions & 30 deletions src/phylojunction/data/tree.py

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions src/phylojunction/data/tree.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class AnnotatedTree(dp.Tree):
tree_reconstructed: dp.Tree
origin_node: ty.Optional[dp.Node]
root_node: ty.Optional[dp.Node]
rec_tr_root_node_label: ty.Optional[str]
brosc_node: ty.Optional[dp.Node]
with_origin: bool
tree_read_as_newick: bool
Expand Down Expand Up @@ -72,7 +73,8 @@ class AnnotatedTree(dp.Tree):
def _populate_node_age_height_dicts(self, unit_branch_lengths: bool = False) -> None: ...
def _prepare_taxon_namespace_for_nexus_printing(self) -> None: ...
def is_extant_or_sa_on_both_sides_complete_tr_root(self, a_node: dp.Node) -> bool: ...
def extract_reconstructed_tree( self, require_obs_both_sides: ty.Optional[bool] = None) -> dp.Tree: ...
def make_at_dict_reflect_rec_tree(self, rec_tree_root_node: dp.Node) -> None: ...
def extract_reconstructed_tree(self, require_obs_both_sides: ty.Optional[bool] = None) -> dp.Tree: ...
def populate_nd_attr_dict(self, attrs_of_interest_list: ty.List[str], attr_dict_added_separately_from_tree: bool = False) -> None: ...
def __str__(self) -> str: ...
def plot_node(self, axes: plt.Axes, node_attr: str = "state", **kwargs) -> None: ...
Expand All @@ -83,7 +85,7 @@ class AnnotatedTree(dp.Tree):
# plotting tree functions
def get_node_name(nd: dp.Node) -> str: ...
def get_x_coord_from_nd_heights(ann_tr: AnnotatedTree, use_age: bool = False, unit_branch_lengths: bool = False) -> ty.Dict[str, float]: ...
def get_y_coord_from_n_obs_nodes(ann_tr: AnnotatedTree, start_at_origin: bool = False, sa_along_branches: bool = True) -> ty.Dict[str, float]: ...
def plot_ann_tree(ann_tr: AnnotatedTree, axes: plt.Axes, use_age: bool = False, start_at_origin: bool = False, attr_of_interest: str = "state", sa_along_branches: bool = True) -> None: ...
def get_y_coord_from_n_obs_nodes(ann_tr: AnnotatedTree, start_at_origin: bool = False, sa_along_branches: bool = True, draw_reconstructed: bool = False) -> ty.Dict[str, float]: ...
def plot_ann_tree(ann_tr: AnnotatedTree, axes: plt.Axes, use_age: bool = False, start_at_origin: bool = False, attr_of_interest: str = "state", sa_along_branches: bool = True, draw_reconstructed: bool = False) -> None: ...
def get_color_map(n_states: int) -> ty.Dict[int, str]: ...
def pj_get_mrca_obs_terminals(a_node: dp.Node, nd_label_list: ty.List[str]) -> dp.Node: ...
112 changes: 82 additions & 30 deletions src/phylojunction/distribution/dn_discrete_sse.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions src/phylojunction/distribution/dn_discrete_sse.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class DnSSE(pgm.DistrForSampling):
clado_state_transition_dict: ty.Dict[str, AttributeTransition],
untargetable_node_set: ty.Set[str],
cumulative_node_count: int,
node_index: int,
sse_birth_rate_object: sseobj.DiscreteStateDependentRate,
event_t: float,
debug=False) -> ty.Tuple[dp.Node, int]: ...
Expand Down Expand Up @@ -107,6 +108,7 @@ class DnSSE(pgm.DistrForSampling):
sa_lineage_dict: ty.Dict[str, ty.List[SampledAncestor]],
untargetable_node_set: ty.Set[str],
cumulative_sa_count: int,
node_index: int,
event_t: float,
debug: bool = False) -> int: ...
def _execute_event(self,
Expand All @@ -120,6 +122,7 @@ class DnSSE(pgm.DistrForSampling):
untargetable_node_set: ty.Set[str],
cumulative_node_count: int,
cumulative_sa_count: int,
node_index: int,
event_t: float,
debug: bool = False) -> ty.Tuple[dp.Node, int, int]: ...
def _annotate_sampled(self,
Expand Down
5 changes: 3 additions & 2 deletions src/phylojunction/distribution/likelihood/ODE/dn_bisse_ode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing as ty
import numpy as np
import math
from scipy.integrate import solve_ivp

__author__ = "Fabio K. Mendes"
Expand Down Expand Up @@ -81,8 +82,8 @@ def solve_bisse_ds_es(ds_es, t_start, t_end, qs, mus, lambdas,
ds_es,
method='RK45',
args=pars,
rtol=1e-4,
atol=1e-4,
rtol=1e-8,
atol=1e-8,
t_eval=[t_end]) # OdeResult object!

ds_es_arr = ds_es.y[:,0]
Expand Down
14 changes: 9 additions & 5 deletions src/phylojunction/distribution/likelihood/ODE/dn_mbt_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@ def mbt_bisse_d_eqn(t, ds_es, ds_es_buffer, qs, mus, b):
# to prevent allocating more memory, but this makes the
# numbers be slightly different from those of the pure BiSSE
# ODEs (in the 5th decimal place!)

es = ds_es[2:] # grab last two elements
ds_es_buffer[2:] = mbt_e_eqn(es, qs, mus, b)
# another option
# ds = mbt_e_eqn(es, qs, mus, b)

ds = ds_es[:2] # grab first two elements
ds_es_buffer[:2] = np.ravel(np.matmul(qs, ds) + np.matmul(b, np.kron(es, ds))) \
ds_es_buffer[:2] = np.ravel(
np.matmul(qs, ds) + np.matmul(b, np.kron(es, ds))
) \
+ np.matmul(b, np.kron(ds, es))
# another option
# es = np.ravel(np.matmul(qs, ds) + np.matmul(b, np.kron(es, ds))) \
Expand Down Expand Up @@ -54,10 +55,13 @@ def solve_mbt_ds_es(ds_es, t_start, t_end, ds_es_buffer, qs, mus, b,
ds_es,
method='RK45',
args=pars,
rtol=1e-1,
atol=1e-1,
t_eval=[t_start + t_end])
rtol=1e-10,
atol=1e-10,
t_eval=[0, .00002, .00004, .00006, t_end])
#, t_eval=[t_start + t_end])

print(ds_es.t)
print(ds_es.y)
ds_es_arr = ds_es.y[:,0]

if verbose:
Expand Down
109 changes: 80 additions & 29 deletions src/phylojunction/distribution/likelihood/lik_dn_mbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,23 @@ def recursively_do_node(nd: dp.Node,
nd_age_dict: ty.Dict[str, float],
pars: ty.Tuple[np.matrix, np.ndarray],
n_states: int,
dn_es: np.ndarray,
log_norm_factors: ty.List[float]) -> np.ndarray:

# get parameters
qs, mus, lambdas = pars
t_start = 0
t_end = nd.edge_length

# get node inf
nd_name = nd.label
nd_age = nd_age_dict(nd_name)

# get integration interval info

# MBT
t_start = nd_age
t_end = nd.edge_length + nd_age
# simple BiSSE
# t_start = 0
# t_end = nd.edge_length

children_ds_es: ty.List[np.ndarray] = list()

Expand All @@ -30,21 +41,29 @@ def recursively_do_node(nd: dp.Node,
nd_age_dict,
pars,
n_states,
dn_es,
log_norm_factors)
children_ds_es.append(ch_ds_es)

n_ch_ds_es = len(children_ds_es)

######################
# Doing current node #
######################

# debugging
# print("doing nd", nd.label)

# new Ds and Es for every node, terminal or internal
ds_es = np.zeros(2 * n_states)

# if internal node
# (combine children if neither is a direct ancestor)
if n_ch_ds_es == 2:
if nd.num_child_nodes() == 2:
left_ds_es, right_ds_es = children_ds_es

# debugging
# print("ds_es (before combining)\n", ds_es)
# print("left_ds_es", left_ds_es)
# print("right_ds_es", right_ds_es)

for i in range(n_states):
# combine Ds
ds_es[i] = lambdas[i] * left_ds_es[i] * right_ds_es[i]
Expand All @@ -67,22 +86,41 @@ def recursively_do_node(nd: dp.Node,
else:
st = int(st)
# obs D is 1, Es remain zero
ds_es[nd.state] = 1

ds_es = \
pjbisse.solve_bisse_ds_es(ds_es,
ds_es[st] = 1

# debugging
# print("ds_es (after combining if internal)", ds_es)
# print("t_start", t_start, "t_end", t_end)

# if not root, we have a branch to integrate over
if nd.parent_node is not None:
# ds_es = \
# pjbisse.solve_bisse_ds_es(ds_es,
# t_start,
# t_end,
# qs,
# mus,
# lambdas,
# verbose=False)

ds_es = \
pjmbt.solve_mbt_ds_es(ds_es,
t_start,
t_end,
qs,
mus,
lambdas)
lambdas,
verbose=False)

norm_factor = np.sum(ds_es)

norm_factor = np.sum(ds_es)
# normalize ds_es
ds_es /= norm_factor

# normalize ds_es
ds_es /= norm_factor
log_norm_factors.append(math.log(norm_factor))

log_norm_factors.append(math.log(norm_factor))
# debugging
# print("returned ds_es", ds_es, "\n")

return ds_es, log_norm_factors

Expand All @@ -99,7 +137,7 @@ def prune_mbt(ann_tr: AnnotatedTree,

# initialize ds_es and log_norm_factors
ds_es = np.zeros(2 * n_states)
log_norm_factors: ty.List[float] = list
log_norm_factors: ty.List[float] = list()

# get origin or root node
seed_nd = ann_tr.origin_node if ann_tr.with_origin \
Expand All @@ -114,19 +152,26 @@ def prune_mbt(ann_tr: AnnotatedTree,
nd_age_dict,
pars,
n_states,
ds_es,
log_norm_factors)

# TODO: now do stuff at root depending on conditioning
log_lk = 0.0
lk = 0.0

# looking at Ds
for i in range(n_states):
log_lik += pi[i] * ds_es[i]
lk += pi[i] * ds_es[i]

log_lk = math.log(lk)

# put back normalization factors
# print("log-norm factor sum", sum(log_norm_factors))
log_lk += sum(log_norm_factors)

if cond_on_survival:
pass

return log_lk


if __name__ == "__main__":

Expand All @@ -136,7 +181,7 @@ def prune_mbt(ann_tr: AnnotatedTree,
# Initializing tree #
#####################
# '?' signalizes ambiguous state (represented as '-1' inside AnnotatedTree's dict members)
tr_str = "((A:1.0[&state=?],B:1.0[&state=?])nd1:1.0[&state=?],C:2.0[&state=?])root[&state=?];"
tr_str = "((A:0.0001[&state=1],B:0.0001[&state=0])nd1:0.0001[&state=?],C:0.0002[&state=0])root[&state=?];"
dp_tree = dp.Tree.get(data=tr_str, schema="newick")

# preparing DendroPy tree to be fed into AnnotatedTree
Expand Down Expand Up @@ -174,18 +219,24 @@ def prune_mbt(ann_tr: AnnotatedTree,
# Initializing parameters #
###########################

qs = np.matrix([[-1.1, 0.1], [0.1, -0.95]])
mus = np.array([.25, .35])
lambdas = np.array([.75, .5])
qs = np.matrix([[-2., .9], [.001, -.2]])
mus = np.array([.1, .1])
lambdas = np.array([1.0, .099])
pi = np.array([.5, .5])
pars = (qs, mus, lambdas, pi)

###################
# Get likelihood! #
###################

prune_mbt(ann_tr,
pars,
n_states,
cond_on_survival=True,
correct_tip_shuffling=False)
log_lk = \
prune_mbt(ann_tr,
pars,
n_states,
cond_on_survival=True,
correct_tip_shuffling=False)

print("Final log-lk =", log_lk)

# tr_str = "((A:0.0001[&state=1],B:0.0001[&state=0])nd1:0.0001[&state=?],C:0.0002[&state=0])root[&state=?];"
# log-lk = -10.00966707670628
22 changes: 20 additions & 2 deletions src/phylojunction/interface/pj_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ def run_example_manual_tree_building(ax: matplotlib.pyplot.Axes) -> None:
def build_tree() -> pjtr.AnnotatedTree:
origin_node = dp.Node(taxon=dp.Taxon(label="origin"), label="origin", edge_length=0.0)
origin_node.state = 0
origin_node.annotations.add_bound_attribute("state")
origin_node.index = 0
origin_node.annotations.add_bound_attribute("index")
origin_node.alive = False
origin_node.sampled = False
origin_node.is_sa = False
Expand All @@ -196,6 +199,9 @@ def build_tree() -> pjtr.AnnotatedTree:

dummy_node = dp.Node(taxon=dp.Taxon(label="dummy1"), label="dummy1", edge_length=1.0)
dummy_node.state = 0
dummy_node.annotations.add_bound_attribute("state")
dummy_node.index = 1
dummy_node.annotations.add_bound_attribute("index")
dummy_node.alive = False
dummy_node.sampled = False
dummy_node.is_sa = False
Expand All @@ -207,6 +213,9 @@ def build_tree() -> pjtr.AnnotatedTree:
# right child of dummy_node
sa_node = dp.Node(taxon=dp.Taxon(label="sa1"), label="sa1", edge_length=0.0)
sa_node.state = 0
sa_node.annotations.add_bound_attribute("state")
sa_node.index = 2
sa_node.annotations.add_bound_attribute("index")
sa_node.alive = False
sa_node.sampled = True
sa_node.is_sa = True
Expand All @@ -216,6 +225,9 @@ def build_tree() -> pjtr.AnnotatedTree:
# left child of dummy node
root_node = dp.Node(taxon=dp.Taxon(label="root"), label="root", edge_length=0.5)
root_node.state = 1
root_node.annotations.add_bound_attribute("state")
root_node.index = 3
root_node.annotations.add_bound_attribute("index")
root_node.alive = False
root_node.sampled = False
root_node.is_sa = False
Expand All @@ -228,6 +240,9 @@ def build_tree() -> pjtr.AnnotatedTree:
# left child of root node
extant_sp1 = dp.Node(taxon=dp.Taxon(label="sp1"), label="sp1", edge_length=0.25)
extant_sp1.state = 2
extant_sp1.annotations.add_bound_attribute("state")
extant_sp1.index = 4
extant_sp1.annotations.add_bound_attribute("index")
extant_sp1.alive = False
extant_sp1.sampled = False
extant_sp1.is_sa = False
Expand All @@ -237,6 +252,9 @@ def build_tree() -> pjtr.AnnotatedTree:
# right child of root node
extant_sp2 = dp.Node(taxon=dp.Taxon(label="sp2"), label="sp2", edge_length=0.5)
extant_sp2.state = 3
extant_sp2.annotations.add_bound_attribute("state")
extant_sp2.index = 5
extant_sp2.annotations.add_bound_attribute("index")
extant_sp2.alive = True
extant_sp2.sampled = True
extant_sp2.is_sa = False
Expand Down Expand Up @@ -448,9 +466,9 @@ def run_example_map_attr(ax: matplotlib.pyplot.Axes) -> None:
# example_to_run = 1
# example_to_run = 2
# example_to_run = 3
# example_to_run = 4
example_to_run = 4
# example_to_run = 5
example_to_run = 6
# example_to_run = 6

if example_to_run == 1:
dag_obj = run_example_yule_string()
Expand Down
Loading

0 comments on commit 73f45fc

Please sign in to comment.