Skip to content

Commit

Permalink
Fix tree tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Oct 11, 2024
1 parent 1b4d726 commit 835d1ec
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 29 deletions.
20 changes: 16 additions & 4 deletions src/emevo/analysis/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from numpy.typing import NDArray
from pyarrow import Table


# Shorthand for ``dataclasses.field`` with ``compare``, ``hash`` and ``repr``
# all switched off; any other keyword (``default``, ``default_factory``, ...)
# is forwarded unchanged to ``dataclasses.field``.
datafield = functools.partial(
    dataclasses.field,
    compare=False,
    hash=False,
    repr=False,
)


Expand Down Expand Up @@ -165,8 +164,8 @@ def from_iter(
iterator: Iterable[tuple[int, int] | tuple[int, int, dict]],
root_idx: int = 0,
) -> Tree:
nodes = {}
root = Node(index=root_idx, is_root=True)
nodes = {}

for item in iterator:
if len(item) == 2:
Expand All @@ -191,6 +190,7 @@ def from_iter(
root.add_child(node)

node.sort_children()
nodes[root_idx] = root
return Tree(root, nodes)

@staticmethod
Expand Down Expand Up @@ -300,7 +300,7 @@ def split(node: Node, threshold: int) -> int:
return size

size = split(self.root, min_group_size)
if size < min_group_size:
if size > 0:
split_nodes[self.root.index] = SplitNode(size)

for node_index in split_nodes:
Expand Down Expand Up @@ -344,6 +344,17 @@ def _split_reward_mean(
split_edges = set()
reward_keys_t = tuple(reward_keys)

# Return the index of the topmost node reachable from `node` by following
# `.parent` links without crossing an edge recorded in `split_edges`.
# `split_edges` is read from the enclosing scope and holds
# (parent_index, child_index) pairs.
# The result is a node index (int), not a Node object.
def find_group_root(node: Node) -> int:
# Index of the node currently treated as the group's root candidate.
ancestor_idx = node.index
ancestor = node.parent
# Climb while a parent exists and the edge parent -> candidate is not split.
while (
ancestor is not None
and (ancestor.index, ancestor_idx) not in split_edges
):
ancestor_idx = ancestor.index
ancestor = ancestor.parent
# Either the climb hit a node without a parent or a split edge blocked it;
# in both cases ancestor_idx is the last node reached.
return ancestor_idx

def find_maxdiff_edge(
frozen_split_edges: frozenset[tuple[int, int]]
) -> tuple[float, Edge]:
Expand All @@ -352,8 +363,9 @@ def find_maxdiff_edge(
for edge in self.all_edges():
if (edge.parent.index, edge.child.index) in split_edges:
continue
parent_root = find_group_root(edge.parent)
parent_size, parent_reward = compute_reward_mean(
edge.parent,
parent_root,
skipped_edges=frozen_split_edges,
reward_keys=reward_keys_t,
)
Expand Down
52 changes: 27 additions & 25 deletions tests/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ def treedef() -> list[tuple[int, int]]:
return [(1, 0), (4, 1), (3, 1), (5, 1), (9, 5), (8, 5), (2, 0), (6, 2), (7, 2)]


@pytest.fixture
def treedef_with_reward() -> list[tuple[int, int]]:
# Edge list for the tree drawn below; read against the diagram, each tuple
# is (child_index, parent_index).
#         0
#       /   \
#      1     2
#     /|\    |\
#    3 4 5   6 7
#        |\
#        8 9
# NOTE(review): despite the name, the tuples carry no reward data and the
# list is identical to the one returned by the `treedef` fixture -- confirm
# whether a per-edge reward dict was meant to be included here.
return [(1, 0), (4, 1), (3, 1), (5, 1), (9, 5), (8, 5), (2, 0), (6, 2), (7, 2)]


def test_from_iter(treedef: list[tuple[int, int]]) -> None:
tree = Tree.from_iter(treedef, root_idx=-1)
preorder = list(map(operator.attrgetter("index"), tree.traverse(preorder=True)))
Expand All @@ -34,41 +46,31 @@ def test_split(treedef: list[tuple[int, int]]) -> None:
tree = Tree.from_iter(treedef)
sp1 = tree.split(min_group_size=3)
assert len(sp1) == 4
assert sp1[0] == 0
for idx in [1, 3, 4]:
assert sp1[idx] == 1
for idx in [2, 6, 7]:
assert sp1[idx] == 2
for idx in [5, 8, 9]:
assert sp1[idx] == 3
parents = [sn for sn in sp1.values() if sn.parent is None]
assert len(parents) == 1
assert sp1[0] == parents[0]
assert sp1[0].size == 1
assert list(sp1[0].children) == [1, 2]
assert sp1[1].size == 3
assert sp1[2].size == 3
assert sp1[5].size == 3

sp2 = tree.split(min_group_size=4)
assert len(sp2) == 4
for idx in [0, 2, 6, 7]:
assert sp2[idx] == 0
for idx in [1, 3, 4, 5, 8, 9]:
assert sp2[idx] == 1


def test_multilabel_split(treedef: list[tuple[int, int]]) -> None:
tree = Tree.from_iter(treedef)
lb1 = tree.multilabel_split(min_group_size=3)
assert len(lb1) == 4
assert list(sorted(lb1[0])) == [0]
assert list(sorted(lb1[1])) == [0, 1, 3, 4]
assert list(sorted(lb1[2])) == [0, 2, 6, 7]
assert list(sorted(lb1[3])) == [0, 1, 5, 8, 9]
assert len(sp2) == 2
assert sp2[0].size == 4, sp2
assert sp2[1].size == 6


def test_from_table() -> None:
table = pq.read_table(ASSET_DIR.joinpath("profile_and_rewards.parquet"))
tree = Tree.from_table(table, 20)
for root in tree.root.children:
assert root.index < 10
assert root.index <= 20
assert root.birth_time is not None
for node in root.traverse():
assert node.birth_time is not None

data_dict = tree.as_datadict(split=10)
for key in ["unique_id", "label", "in-label-0", "in-label-1"]:
split = tree.split(min_group_size=10)
data_dict = tree.as_datadict(split)
for key in ["unique_id", "label"]:
assert key in data_dict

0 comments on commit 835d1ec

Please sign in to comment.