diff --git a/src/emevo/analysis/tree.py b/src/emevo/analysis/tree.py index 7003672..db955fc 100644 --- a/src/emevo/analysis/tree.py +++ b/src/emevo/analysis/tree.py @@ -15,7 +15,6 @@ from numpy.typing import NDArray from pyarrow import Table - datafield = functools.partial(dataclasses.field, compare=False, hash=False, repr=False) @@ -165,8 +164,8 @@ def from_iter( iterator: Iterable[tuple[int, int] | tuple[int, int, dict]], root_idx: int = 0, ) -> Tree: - nodes = {} root = Node(index=root_idx, is_root=True) + nodes = {} for item in iterator: if len(item) == 2: @@ -191,6 +190,7 @@ def from_iter( root.add_child(node) node.sort_children() + nodes[root_idx] = root return Tree(root, nodes) @staticmethod @@ -300,7 +300,7 @@ def split(node: Node, threshold: int) -> int: return size size = split(self.root, min_group_size) - if size < min_group_size: + if size > 0: split_nodes[self.root.index] = SplitNode(size) for node_index in split_nodes: @@ -344,6 +344,17 @@ def _split_reward_mean( split_edges = set() reward_keys_t = tuple(reward_keys) + def find_group_root(node: Node) -> int: + ancestor_idx = node.index + ancestor = node.parent + while ( + ancestor is not None + and (ancestor.index, ancestor_idx) not in split_edges + ): + ancestor_idx = ancestor.index + ancestor = ancestor.parent + return ancestor_idx + def find_maxdiff_edge( frozen_split_edges: frozenset[tuple[int, int]] ) -> tuple[float, Edge]: @@ -352,8 +363,9 @@ def find_maxdiff_edge( for edge in self.all_edges(): if (edge.parent.index, edge.child.index) in split_edges: continue + parent_root = find_group_root(edge.parent) parent_size, parent_reward = compute_reward_mean( - edge.parent, + parent_root, skipped_edges=frozen_split_edges, reward_keys=reward_keys_t, ) diff --git a/tests/test_tree.py b/tests/test_tree.py index cac2c47..e49fb9a 100644 --- a/tests/test_tree.py +++ b/tests/test_tree.py @@ -21,6 +21,18 @@ def treedef() -> list[tuple[int, int]]: return [(1, 0), (4, 1), (3, 1), (5, 1), (9, 5), (8, 5), (2, 0), (6, 2), (7, 2)] +@pytest.fixture +def treedef_with_reward() -> list[tuple[int, int]]: + # 0 + # / \ + # 1 2 + # /|\ |\ + # 3 4 5 6 7 + # |\ + # 8 9 + return [(1, 0), (4, 1), (3, 1), (5, 1), (9, 5), (8, 5), (2, 0), (6, 2), (7, 2)] + + def test_from_iter(treedef: list[tuple[int, int]]) -> None: tree = Tree.from_iter(treedef, root_idx=-1) preorder = list(map(operator.attrgetter("index"), tree.traverse(preorder=True))) @@ -34,41 +46,31 @@ def test_split(treedef: list[tuple[int, int]]) -> None: tree = Tree.from_iter(treedef) sp1 = tree.split(min_group_size=3) assert len(sp1) == 4 - assert sp1[0] == 0 - for idx in [1, 3, 4]: - assert sp1[idx] == 1 - for idx in [2, 6, 7]: - assert sp1[idx] == 2 - for idx in [5, 8, 9]: - assert sp1[idx] == 3 + parents = [sn for sn in sp1.values() if sn.parent is None] + assert len(parents) == 1 + assert sp1[0] == parents[0] + assert sp1[0].size == 1 + assert list(sp1[0].children) == [1, 2] + assert sp1[1].size == 3 + assert sp1[2].size == 3 + assert sp1[5].size == 3 sp2 = tree.split(min_group_size=4) - assert len(sp2) == 4 - for idx in [0, 2, 6, 7]: - assert sp2[idx] == 0 - for idx in [1, 3, 4, 5, 8, 9]: - assert sp2[idx] == 1 - - -def test_multilabel_split(treedef: list[tuple[int, int]]) -> None: - tree = Tree.from_iter(treedef) - lb1 = tree.multilabel_split(min_group_size=3) - assert len(lb1) == 4 - assert list(sorted(lb1[0])) == [0] - assert list(sorted(lb1[1])) == [0, 1, 3, 4] - assert list(sorted(lb1[2])) == [0, 2, 6, 7] - assert list(sorted(lb1[3])) == [0, 1, 5, 8, 9] + assert len(sp2) == 2 + assert sp2[0].size == 4, sp2 + assert sp2[1].size == 6 def test_from_table() -> None: table = pq.read_table(ASSET_DIR.joinpath("profile_and_rewards.parquet")) tree = Tree.from_table(table, 20) for root in tree.root.children: - assert root.index < 10 + assert root.index <= 20 assert root.birth_time is not None for node in root.traverse(): assert node.birth_time is not None - data_dict = tree.as_datadict(split=10) - for key in ["unique_id", "label", "in-label-0", "in-label-1"]: + split = tree.split(min_group_size=10) + data_dict = tree.as_datadict(split) + for key in ["unique_id", "label"]: assert key in data_dict