Skip to content

Commit

Permalink
Add rec tr clado_at_dict update code
Browse files Browse the repository at this point in the history
  • Loading branch information
binho authored and binho committed Apr 25, 2024
1 parent aba795a commit ad275c6
Show file tree
Hide file tree
Showing 6 changed files with 855 additions and 37 deletions.
36 changes: 30 additions & 6 deletions src/phylojunction/data/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ class AnnotatedTree(dp.Tree):
# dictionary of cladogenetic attribute transitions
clado_at_dict: \
ty.Optional[ty.Dict[str, pjat.AttributeTransition]] # can be None
rec_tr_clado_at_dict: \
ty.Optional[ty.Dict[str, pjat.AttributeTransition]] # can be None

# to deal with effectively zero floats
epsilon: float
Expand Down Expand Up @@ -350,6 +352,7 @@ def __init__(
self.at_dict = at_dict
self.rec_tr_at_dict = copy.deepcopy(at_dict)
self.clado_at_dict = clado_at_dict
self.rec_tr_clado_at_dict = copy.deepcopy(clado_at_dict)

# node counting
self.n_extant_terminal_nodes = 0
Expand Down Expand Up @@ -1115,13 +1118,17 @@ def _recur_find_extant_or_sa(
return False

def update_rec_tr_at_dict(self, rec_tree_root_nd: dp.Node) -> None:
"""Update 'rec_tr_at_dict' member.
"""Update 'rec_tr_at_dict' and 'rec_clado-at_dict' members.
The 'at_dict' member of the AnnotatedTree, when defined, will by
default host the state transitions of every node of the complete
tree. This method initializes 'rec_trat_dict' so that it reflects the
reconstructed tree -- it is only called when necessary,
by the 'extract_reconstructed_tree' method.
tree. This method initializes 'rec_tr_at_dict' so that it
reflects the reconstructed tree -- it is only called when
necessary, by the 'extract_reconstructed_tree' method.
Member 'rec_tr_clado_at_dict' is also updated. Internal nodes
undergoing cladogenetic changes that are in the complete tree
but not in the reconstructed tree are removed.
"""

def recur_grabbing_int_nds_to_merge(
Expand Down Expand Up @@ -1276,6 +1283,25 @@ def recur_grabbing_int_nds_to_merge(
nd.label in self.rec_tr_at_dict:
del self.rec_tr_at_dict[nd.label]

##############################
# Updating rec_clado_at_dict #
##############################

if self.clado_at_dict is not None:
complete_tr_clado_at_dict_nd_name_set = \
{nd_name for nd_name in self.clado_at_dict.keys()}

rec_tr_int_nd_label_set = \
{ind.label \
for ind in self.tree_reconstructed.preorder_internal_node_iter()}

in_complete_not_in_rec_set = \
complete_tr_clado_at_dict_nd_name_set - rec_tr_int_nd_label_set

for nd_name in in_complete_not_in_rec_set:
del self.rec_tr_clado_at_dict[nd_name]


def update_rec_tr_sa_lineage_dict(self) -> None:
"""Update 'rec_tr_sa_lineage_dict' member.
Expand Down Expand Up @@ -2569,8 +2595,6 @@ def _draw_clade(nd: dp.Node,
y_top = y_coords[get_node_name(children[0])]
y_bot = y_coords[get_node_name(children[1])]

print("nd", nd_name, "color", segment_colors)

# last color in segment_colors will
# match the state of the node whose
# subtending branch we are drawing
Expand Down
1 change: 1 addition & 0 deletions src/phylojunction/data/tree.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class AnnotatedTree(dp.Tree):
origin_age: ty.Optional[float]
origin_edge_length: float
root_age: float
rec_tr_root_age: ty.Optional[float]
node_heights_dict: ty.Dict[str, float]
node_ages_dict: ty.Dict[str, float]
slice_t_ends: ty.Optional[ty.List[float]]
Expand Down
24 changes: 14 additions & 10 deletions src/phylojunction/readwrite/pj_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,14 +306,17 @@ def recursively_collect_smaps(nd: dp.Node,
it_idx_str = "1"
sample_smap_str = ""
for ann_tr_repl in node_val:
# TODO: prune repl and enter info in df
seed_nd = \
ann_tr_repl.origin_node if ann_tr_repl.with_origin \
else ann_tr_repl.root_node
node_ages_dict = ann_tr_repl.node_ages_dict
complete_tr_age = ann_tr_repl.seed_age
# method call updates ann_tr_repl's members
rec_tr = ann_tr_repl.extract_reconstructed_tree()
at_dict = ann_tr_repl.at_dict
root_node = ann_tr_repl.rec_tr_root_node
root_age = ann_tr_repl.rec_tr_root_age
node_ages_dict = ann_tr_repl.rec_node_ages_dict
at_dict = ann_tr_repl.rec_tr_at_dict

if rec_tr is None:
print(("Could not find a reconstructed tree. "
"Exiting..."))

clado_at_dict = ann_tr_repl.clado_at_dict

# if we are done with a batch of replicates, we reset everything
Expand All @@ -322,10 +325,11 @@ def recursively_collect_smaps(nd: dp.Node,
it_idx_str = "1"
sample_smap_str = ""

sample_smap_str = recursively_collect_smaps(seed_nd,
at_dict,
sample_smap_str = \
recursively_collect_smaps(root_node,
at_dict,
clado_at_dict,
complete_tr_age,
root_age,
node_ages_dict,
sample_smap_str,
it_idx_str)
Expand Down
Loading

0 comments on commit ad275c6

Please sign in to comment.