diff --git a/src/phylojunction/data/tree.py b/src/phylojunction/data/tree.py index 717b0d1..76207da 100644 --- a/src/phylojunction/data/tree.py +++ b/src/phylojunction/data/tree.py @@ -226,6 +226,8 @@ class AnnotatedTree(dp.Tree): # dictionary of cladogenetic attribute transitions clado_at_dict: \ ty.Optional[ty.Dict[str, pjat.AttributeTransition]] # can be None + rec_tr_clado_at_dict: \ + ty.Optional[ty.Dict[str, pjat.AttributeTransition]] # can be None # to deal with effectively zero floats epsilon: float @@ -350,6 +352,7 @@ def __init__( self.at_dict = at_dict self.rec_tr_at_dict = copy.deepcopy(at_dict) self.clado_at_dict = clado_at_dict + self.rec_tr_clado_at_dict = copy.deepcopy(clado_at_dict) # node counting self.n_extant_terminal_nodes = 0 @@ -1115,13 +1118,17 @@ def _recur_find_extant_or_sa( return False def update_rec_tr_at_dict(self, rec_tree_root_nd: dp.Node) -> None: - """Update 'rec_tr_at_dict' member. + """Update 'rec_tr_at_dict' and 'rec_tr_clado_at_dict' members. The 'at_dict' member of the AnnotatedTree, when defined, will by default host the state transitions of every node of the complete - tree. This method initializes 'rec_trat_dict' so that it reflects the - reconstructed tree -- it is only called when necessary, - by the 'extract_reconstructed_tree' method. + tree. This method initializes 'rec_tr_at_dict' so that it + reflects the reconstructed tree -- it is only called when + necessary, by the 'extract_reconstructed_tree' method. + + Member 'rec_tr_clado_at_dict' is also updated. Internal nodes + undergoing cladogenetic changes that are in the complete tree + but not in the reconstructed tree are removed. 
""" def recur_grabbing_int_nds_to_merge( @@ -1276,6 +1283,25 @@ def recur_grabbing_int_nds_to_merge( nd.label in self.rec_tr_at_dict: del self.rec_tr_at_dict[nd.label] + ############################## + # Updating rec_clado_at_dict # + ############################## + + if self.clado_at_dict is not None: + complete_tr_clado_at_dict_nd_name_set = \ + {nd_name for nd_name in self.clado_at_dict.keys()} + + rec_tr_int_nd_label_set = \ + {ind.label \ + for ind in self.tree_reconstructed.preorder_internal_node_iter()} + + in_complete_not_in_rec_set = \ + complete_tr_clado_at_dict_nd_name_set - rec_tr_int_nd_label_set + + for nd_name in in_complete_not_in_rec_set: + del self.rec_tr_clado_at_dict[nd_name] + + def update_rec_tr_sa_lineage_dict(self) -> None: """Update 'rec_tr_sa_lineage_dict' member. @@ -2569,8 +2595,6 @@ def _draw_clade(nd: dp.Node, y_top = y_coords[get_node_name(children[0])] y_bot = y_coords[get_node_name(children[1])] - print("nd", nd_name, "color", segment_colors) - # last color in segment_colors will # match the state of the node whose # subtending branch we are drawing diff --git a/src/phylojunction/data/tree.pyi b/src/phylojunction/data/tree.pyi index df84860..005338d 100644 --- a/src/phylojunction/data/tree.pyi +++ b/src/phylojunction/data/tree.pyi @@ -25,6 +25,7 @@ class AnnotatedTree(dp.Tree): origin_age: ty.Optional[float] origin_edge_length: float root_age: float + rec_tr_root_age: ty.Optional[float] node_heights_dict: ty.Dict[str, float] node_ages_dict: ty.Dict[str, float] slice_t_ends: ty.Optional[ty.List[float]] diff --git a/src/phylojunction/readwrite/pj_write.py b/src/phylojunction/readwrite/pj_write.py index 242ba61..7e039af 100644 --- a/src/phylojunction/readwrite/pj_write.py +++ b/src/phylojunction/readwrite/pj_write.py @@ -306,14 +306,17 @@ def recursively_collect_smaps(nd: dp.Node, it_idx_str = "1" sample_smap_str = "" for ann_tr_repl in node_val: - # TODO: prune repl and enter info in df - seed_nd = \ - ann_tr_repl.origin_node if 
ann_tr_repl.with_origin \ - else ann_tr_repl.root_node - node_ages_dict = ann_tr_repl.node_ages_dict - complete_tr_age = ann_tr_repl.seed_age + # method call updates ann_tr_repl's members rec_tr = ann_tr_repl.extract_reconstructed_tree() - at_dict = ann_tr_repl.at_dict + root_node = ann_tr_repl.rec_tr_root_node + root_age = ann_tr_repl.rec_tr_root_age + node_ages_dict = ann_tr_repl.rec_node_ages_dict + at_dict = ann_tr_repl.rec_tr_at_dict + + if rec_tr is None: + print(("Could not find a reconstructed tree. " + "Exiting...")) + clado_at_dict = ann_tr_repl.clado_at_dict # if we are done with a batch of replicates, we reset everything @@ -322,10 +325,11 @@ def recursively_collect_smaps(nd: dp.Node, it_idx_str = "1" sample_smap_str = "" - sample_smap_str = recursively_collect_smaps(seed_nd, - at_dict, + sample_smap_str = \ + recursively_collect_smaps(root_node, + at_dict, clado_at_dict, - complete_tr_age, + root_age, node_ages_dict, sample_smap_str, it_idx_str) diff --git a/tests/data/baseline_figs/test_harder_geosse_rec_tr.svg b/tests/data/baseline_figs/test_harder_geosse_rec_tr.svg new file mode 100644 index 0000000..7e33a9a --- /dev/null +++ b/tests/data/baseline_figs/test_harder_geosse_rec_tr.svg @@ -0,0 +1,500 @@ + + + + + + + + 2024-04-25T12:24:13.716645 + image/svg+xml + + + Matplotlib v3.8.2, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/test_tree_extract_reconstructed.py b/tests/data/test_tree_extract_reconstructed.py index fbc8963..23cd92c 100644 --- a/tests/data/test_tree_extract_reconstructed.py +++ b/tests/data/test_tree_extract_reconstructed.py @@ -1778,7 +1778,10 @@ def 
test_extract_reconstructed_tree_origin_one_root_side_dies_three_survive_sa_b self.assertEqual(tr_rec_str2, ";\n") def test_make_at_dict_reflect_rec_tree_origin_two_extinct(self) -> None: - """Test method that updates 'rec_tr_at_dict' for reconstructed tree.""" + """Test method that updates 'rec_tr_at_dict' for reconstructed tree. + + See also test_tree_plotting.test_harder_geosse_rec_tr() + """ origin_node = Node(taxon=Taxon(label="origin"), label="origin", edge_length=0.0) origin_node.state = 2 # AB @@ -1994,12 +1997,46 @@ def test_make_at_dict_reflect_rec_tree_origin_two_extinct(self) -> None: "sp5": [at4, at5] } + # internal_node2 + clado_at1 = pjat.AttributeTransition("state", + subtending_node_label="nd5", + global_time=3.0, + from_state=2, + to_state=0, + to_state2=1) + clado_at2 = pjat.AttributeTransition("state", + subtending_node_label="nd7", + global_time=3.0, + from_state=2, + to_state=0, + to_state2=1) + clado_at3 = pjat.AttributeTransition("state", + subtending_node_label="nd6", + global_time=2.0, + from_state=2, + to_state=0, + to_state2=1) + clado_at4 = pjat.AttributeTransition("state", + subtending_node_label="nd8", + global_time=2.0, + from_state=2, + to_state=0, + to_state2=1) + clado_at_dict = { + "nd5": [clado_at1], + "nd7": [clado_at2], + "nd6": [clado_at3], + "nd8": [clado_at4] + } + + total_state_count = 3 max_age = 5.0 ann_tr = pjtr.AnnotatedTree( tr_complete, total_state_count, at_dict=at_dict, + clado_at_dict=clado_at_dict, start_at_origin=True, max_age=max_age, epsilon=1e-12) @@ -2027,26 +2064,6 @@ def test_make_at_dict_reflect_rec_tree_origin_two_extinct(self) -> None: # suppress_rooting=True) # print(tr_rec_str) - # fig = matplotlib.pyplot.figure() - # ax = fig.add_axes([0.25, 0.2, 0.5, 0.6]) - # ax.patch.set_alpha(0.0) - # ax.xaxis.set_ticks([]) - # ax.yaxis.set_ticks([]) - # ax.spines['left'].set_visible(False) - # ax.spines['bottom'].set_visible(False) - # ax.spines['right'].set_visible(False) - # ax.spines['top'].set_visible(False) 
- - # plotting complete or rec tree - # draw_reconstructed = True - # pjtr.plot_ann_tree(ann_tr, - # ax, - # use_age=False, - # sa_along_branches=False, - # attr_of_interest="state", - # draw_reconstructed=draw_reconstructed) - # matplotlib.pyplot.show() - rec_tr_at_dict = ann_tr.rec_tr_at_dict # check that the right nodes are in the rec tree's at_dict @@ -2071,6 +2088,8 @@ def test_make_at_dict_reflect_rec_tree_origin_two_extinct(self) -> None: {"sp2": [1.0, 1.5, 2.0, 3.0], "sp5": [1.0, 1.5, 2.0, 3.0]}) + self.assertEqual({}, ann_tr.rec_tr_clado_at_dict) + if __name__ == "__main__": # Assuming you opened the PhyloJunction/ (repo root) folder diff --git a/tests/data/test_tree_plotting.py b/tests/data/test_tree_plotting.py index 9b9da14..e4cd89f 100644 --- a/tests/data/test_tree_plotting.py +++ b/tests/data/test_tree_plotting.py @@ -1102,6 +1102,7 @@ def test_sa_followed_by_sa_tip_fbd_rec_tr_plot(self) -> None: self.assertEqual(exp_y_coords, y_coords) def test_easy_geosse_rec_tr(self) -> None: + """Test plot for GeoSSE tree.""" origin_node = Node(taxon=Taxon(label="origin"), label="origin", edge_length=0.0) origin_node.state = 2 # AB @@ -1246,6 +1247,275 @@ def test_easy_geosse_rec_tr(self) -> None: self.assertEqual({'sp1': 2.0, 'sp3': 2.0, 'root': 0.0}, x_coords) self.assertEqual({'sp3': 3.0, 'sp1': 2.0, 'root': 2.5}, y_coords) + def test_harder_geosse_rec_tr(self) -> None: + """Test plot for GeoSSE tree with 4 extinct tips.""" + + origin_node = Node(taxon=Taxon(label="origin"), label="origin", edge_length=0.0) + origin_node.state = 2 # AB + origin_node.annotations.add_bound_attribute("state") + origin_node.index = 0 + origin_node.annotations.add_bound_attribute("index") + origin_node.alive = False + origin_node.sampled = False + origin_node.is_sa = False + origin_node.is_sa_dummy_parent = False + origin_node.is_sa_lineage = False + + root_node = Node(taxon=Taxon(label="root"), label="root", edge_length=1.0) + root_node.state = 2 # AB + 
root_node.annotations.add_bound_attribute("state") + root_node.index = 1 + root_node.annotations.add_bound_attribute("index") + root_node.alive = False + root_node.sampled = False + root_node.is_sa = False + root_node.is_sa_dummy_parent = False + root_node.is_sa_lineage = False + + # left child of root_node + internal_node1 = Node(taxon=Taxon(label="nd6"), + label="nd6", + edge_length=1.0) + internal_node1.state = 2 # AB + internal_node1.annotations.add_bound_attribute("state") + internal_node1.index = 2 + internal_node1.annotations.add_bound_attribute("index") + internal_node1.alive = False + internal_node1.sampled = False + internal_node1.is_sa = False + internal_node1.is_sa_dummy_parent = False + internal_node1.is_sa_lineage = False + + # right child of root_node + internal_node3 = Node(taxon=Taxon(label="nd8"), + label="nd8", + edge_length=1.0) + internal_node3.state = 2 # AB + internal_node3.annotations.add_bound_attribute("state") + internal_node3.index = 3 + internal_node3.annotations.add_bound_attribute("index") + internal_node3.alive = False + internal_node3.sampled = False + internal_node3.is_sa = False + internal_node3.is_sa_dummy_parent = False + internal_node3.is_sa_lineage = False + + # left child of internal_node1 + internal_node2 = Node(taxon=Taxon(label="nd5"), + label="nd5", + edge_length=1.0) + internal_node2.state = 2 # AB + internal_node2.annotations.add_bound_attribute("state") + internal_node2.index = 4 + internal_node2.annotations.add_bound_attribute("index") + internal_node2.alive = False + internal_node2.sampled = False + internal_node2.is_sa = False + internal_node2.is_sa_dummy_parent = False + internal_node2.is_sa_lineage = False + + # right child of internal_node1 + extinct_sp3 = Node(taxon=Taxon(label="sp3"), + label="sp3", + edge_length=2.0) + extinct_sp3.state = 0 # A + extinct_sp3.annotations.add_bound_attribute("state") + extinct_sp3.index = 6 + extinct_sp3.annotations.add_bound_attribute("index") + extinct_sp3.alive = False + 
extinct_sp3.sampled = False + extinct_sp3.is_sa = False + extinct_sp3.is_sa_dummy_parent = False + extinct_sp3.is_sa_lineage = False + + # left child of internal_node2 + extinct_sp1 = Node(taxon=Taxon(label="sp1"), + label="sp1", + edge_length=1.0) + extinct_sp1.state = 1 # B + extinct_sp1.annotations.add_bound_attribute("state") + extinct_sp1.index = 7 + extinct_sp1.annotations.add_bound_attribute("index") + extinct_sp1.alive = False + extinct_sp1.sampled = False + extinct_sp1.is_sa = False + extinct_sp1.is_sa_dummy_parent = False + extinct_sp1.is_sa_lineage = False + + # right child of internal_node2 + extant_sp2 = Node(taxon=Taxon(label="sp2"), + label="sp2", + edge_length=2.0) + extant_sp2.state = 2 # AB + extant_sp2.annotations.add_bound_attribute("state") + extant_sp2.index = 8 + extant_sp2.annotations.add_bound_attribute("index") + extant_sp2.alive = True + extant_sp2.sampled = True + extant_sp2.is_sa = False + extant_sp2.is_sa_dummy_parent = False + extant_sp2.is_sa_lineage = False + + # left child of internal_node3 + internal_node4 = Node(taxon=Taxon(label="nd7"), + label="nd7", + edge_length=1.0) + internal_node4.state = 2 # AB + internal_node4.annotations.add_bound_attribute("state") + internal_node4.index = 5 + internal_node4.annotations.add_bound_attribute("index") + internal_node4.alive = False + internal_node4.sampled = False + internal_node4.is_sa = False + internal_node4.is_sa_dummy_parent = False + internal_node4.is_sa_lineage = False + + # right child of internal_node3 + extinct_sp6 = Node(taxon=Taxon(label="sp6"), + label="sp6", + edge_length=2.0) + extinct_sp6.state = 0 # A + extinct_sp6.annotations.add_bound_attribute("state") + extinct_sp6.index = 9 + extinct_sp6.annotations.add_bound_attribute("index") + extinct_sp6.alive = False + extinct_sp6.sampled = False + extinct_sp6.is_sa = False + extinct_sp6.is_sa_dummy_parent = False + extinct_sp6.is_sa_lineage = False + + # left child of internal_node4 + extinct_sp4 = 
Node(taxon=Taxon(label="sp4"), + label="sp4", + edge_length=1.0) + extinct_sp4.state = 1 # B + extinct_sp4.annotations.add_bound_attribute("state") + extinct_sp4.index = 10 + extinct_sp4.annotations.add_bound_attribute("index") + extinct_sp4.alive = False + extinct_sp4.sampled = False + extinct_sp4.is_sa = False + extinct_sp4.is_sa_dummy_parent = False + extinct_sp4.is_sa_lineage = False + + # right child of internal_node4 + extant_sp5 = Node(taxon=Taxon(label="sp5"), + label="sp5", + edge_length=2.0) + extant_sp5.state = 2 # AB + extant_sp5.annotations.add_bound_attribute("state") + extant_sp5.index = 11 + extant_sp5.annotations.add_bound_attribute("index") + extant_sp5.alive = True + extant_sp5.sampled = True + extant_sp5.is_sa = False + extant_sp5.is_sa_dummy_parent = False + extant_sp5.is_sa_lineage = False + + # building topology + internal_node2.add_child(extinct_sp1) + internal_node2.add_child(extant_sp2) + + internal_node1.add_child(internal_node2) # 'nd5' + internal_node1.add_child(extinct_sp3) + + internal_node4.add_child(extinct_sp4) + internal_node4.add_child(extant_sp5) + + internal_node3.add_child(internal_node4) # 'nd7' + internal_node3.add_child(extinct_sp6) + + root_node.add_child(internal_node1) # 'nd6' + root_node.add_child(internal_node3) # 'nd8' + + origin_node.add_child(root_node) + + # wrapping up tree + tr_complete = Tree(seed_node=origin_node) + tr_complete.taxon_namespace.add_taxon(origin_node.taxon) + tr_complete.taxon_namespace.add_taxon(root_node.taxon) + tr_complete.taxon_namespace.add_taxon(internal_node1.taxon) + tr_complete.taxon_namespace.add_taxon(internal_node2.taxon) + tr_complete.taxon_namespace.add_taxon(internal_node3.taxon) + tr_complete.taxon_namespace.add_taxon(internal_node4.taxon) + tr_complete.taxon_namespace.add_taxon(extinct_sp3.taxon) + tr_complete.taxon_namespace.add_taxon(extinct_sp1.taxon) + tr_complete.taxon_namespace.add_taxon(extinct_sp4.taxon) + tr_complete.taxon_namespace.add_taxon(extinct_sp6.taxon) + 
tr_complete.taxon_namespace.add_taxon(extant_sp2.taxon) + tr_complete.taxon_namespace.add_taxon(extant_sp5.taxon) + + at1_1_1 = pjat.AttributeTransition("state", "nd5", 2.0, 2, 1) + at1_1_2 = pjat.AttributeTransition("state", "nd5", 2.5, 1, 2) + at1_1_3 = pjat.AttributeTransition("state", "nd7", 2.0, 2, 1) + at1_1_4 = pjat.AttributeTransition("state", "nd7", 2.5, 1, 2) + at1_2_1 = pjat.AttributeTransition("state", "sp3", 2.0, 2, 0) + at1_2_2 = pjat.AttributeTransition("state", "sp6", 2.0, 2, 0) + at2 = pjat.AttributeTransition("state", "sp2", 3.0, 2, 0) + at3 = pjat.AttributeTransition("state", "sp2", 4.0, 0, 2) + at4 = pjat.AttributeTransition("state", "sp5", 3.0, 2, 0) + at5 = pjat.AttributeTransition("state", "sp5", 4.0, 0, 2) + at_dict = { + "nd5": [at1_1_1, at1_1_2], + "nd7": [at1_1_3, at1_1_4], + "sp3": [at1_2_1], + "sp6": [at1_2_2], + "sp2": [at2, at3], + "sp5": [at4, at5] + } + + total_state_count = 3 + max_age = 5.0 + ann_tr = pjtr.AnnotatedTree( + tr_complete, + total_state_count, + at_dict=at_dict, + start_at_origin=True, + max_age=max_age, + epsilon=1e-12) + + ann_tr.populate_nd_attr_dict("state") + + # ann_tr_str = \ + # ann_tr.tree.as_string( + # schema="newick", + # suppress_internal_taxon_labels=True, + # suppress_internal_node_labels=False) + # print(ann_tr_str) + + # updates rec tree-related members in 'ann_tr' + tr_rec = \ + ann_tr.extract_reconstructed_tree( + plotting_overhead=True, + require_obs_both_sides=False) + + fig = matplotlib.pyplot.figure() + ax = fig.add_axes([0.25, 0.2, 0.5, 0.6]) + ax.patch.set_alpha(0.0) + ax.xaxis.set_ticks([]) + ax.yaxis.set_ticks([]) + ax.spines['left'].set_visible(False) + ax.spines['bottom'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.spines['top'].set_visible(False) + + # plotting complete or rec tree + draw_reconstructed = True + x_coords, y_coords = \ + pjtr.plot_ann_tree(ann_tr, + ax, + use_age=False, + sa_along_branches=False, + attr_of_interest="state", + 
draw_reconstructed=draw_reconstructed) + # matplotlib.pyplot.show() # to see it (compare to baseline figs!) + # new_svg_path = "baseline_figs/test_harder_geosse_rec_tr.svg" + # matplotlib.pyplot.savefig(new_svg_path) + + self.assertEqual({'sp2': 4.0, 'sp5': 4.0, 'root': 0.0}, + x_coords) + self.assertEqual({'sp5': 3.0, 'sp2': 2.0, 'root': 2.5}, + y_coords) if __name__ == "__main__":