From b002c0e1f3a85f59e4870ac7a0b201e87d24b32d Mon Sep 17 00:00:00 2001 From: kngwyu Date: Fri, 18 Oct 2024 11:54:19 +0900 Subject: [PATCH] Fix root problem in tree splitting --- Makefile | 2 +- src/emevo/analysis/tree.py | 13 ++++++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 7519d3f..780df93 100644 --- a/Makefile +++ b/Makefile @@ -26,7 +26,7 @@ register: jupyter: - uv run --with jupyter jupyter lab --port=9998 --no-browser + uv run --with jupyter --with jupyterlab_code_formatter --with black --with isort jupyter lab --port=9998 --no-browser sync: diff --git a/src/emevo/analysis/tree.py b/src/emevo/analysis/tree.py index 8a371cd..6d9aa29 100644 --- a/src/emevo/analysis/tree.py +++ b/src/emevo/analysis/tree.py @@ -374,12 +374,19 @@ def find_maxdiff_edge( ) -> tuple[float, Edge]: max_effect = 0.0 max_effect_edge = None + failure_causes = { + "Edge already used": 0, + "Group size is too small": 0, + "Effect is too small": 0, + } for edge in self.all_edges(): if (edge.parent.index, edge.child.index) in split_edges: + failure_causes["Edge already used"] += 1 continue parent_root = self.nodes[find_group_root(edge.parent)] parent_size, parent_reward = compute_reward_mean( parent_root, + is_root=parent_root.index == self.root.index, skipped_edges=frozen_split_edges, reward_keys=reward_keys_t, ) @@ -392,6 +399,7 @@ def find_maxdiff_edge( child_size < min_group_size or (parent_size - child_size) < min_group_size ): + failure_causes["Group size is too small"] += 1 continue assert parent_size > child_size, (parent_size, child_size) split_size = parent_size - child_size @@ -405,7 +413,9 @@ def find_maxdiff_edge( if effect > max_effect: max_effect = effect max_effect_edge = edge - assert max_effect_edge is not None, "Couldn't find maxdiff_edge anymore" + else: + failure_causes["Effect is too small"] += 1 + assert max_effect_edge is not None, f"Couldn't find maxdiff_edge anymore (Reason: {failure_causes})" return max_effect, max_effect_edge for _ in range(n_trial): @@ -414,6 +424,7 @@ def find_maxdiff_edge( parent_root = self.nodes[find_group_root(edge.parent)] parent_size, parent_reward = compute_reward_mean( parent_root, + is_root=parent_root.index == self.root.index, skipped_edges=frozen_split_edges, reward_keys=reward_keys_t, )