Skip to content

Commit

Permalink
Fix root problem in tree splitting
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Oct 18, 2024
1 parent c2872a6 commit b002c0e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ register:


jupyter:
uv run --with jupyter jupyter lab --port=9998 --no-browser
uv run --with jupyter --with jupyterlab_code_formatter --with black --with isort jupyter lab --port=9998 --no-browser


sync:
Expand Down
13 changes: 12 additions & 1 deletion src/emevo/analysis/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,12 +374,19 @@ def find_maxdiff_edge(
) -> tuple[float, Edge]:
max_effect = 0.0
max_effect_edge = None
failure_causes = {
"Edge already used": 0,
"Group size is too small": 0,
"Effect is too small": 0,
}
for edge in self.all_edges():
if (edge.parent.index, edge.child.index) in split_edges:
failure_causes["Edge already used"] += 1
continue
parent_root = self.nodes[find_group_root(edge.parent)]
parent_size, parent_reward = compute_reward_mean(
parent_root,
is_root=parent_root.index == self.root.index,
skipped_edges=frozen_split_edges,
reward_keys=reward_keys_t,
)
Expand All @@ -392,6 +399,7 @@ def find_maxdiff_edge(
child_size < min_group_size
or (parent_size - child_size) < min_group_size
):
failure_causes["Group size is too small"] += 1
continue
assert parent_size > child_size, (parent_size, child_size)
split_size = parent_size - child_size
Expand All @@ -405,7 +413,9 @@ def find_maxdiff_edge(
if effect > max_effect:
max_effect = effect
max_effect_edge = edge
assert max_effect_edge is not None, "Couldn't find maxdiff_edge anymore"
else:
failure_causes["Effect is too small"] += 1
assert max_effect_edge is not None, f"Couldn't find maxdiff_edge anymore (Reason: {failure_causes})"
return max_effect, max_effect_edge

for _ in range(n_trial):
Expand All @@ -414,6 +424,7 @@ def find_maxdiff_edge(
parent_root = self.nodes[find_group_root(edge.parent)]
parent_size, parent_reward = compute_reward_mean(
parent_root,
is_root=parent_root.index == self.root.index,
skipped_edges=frozen_split_edges,
reward_keys=reward_keys_t,
)
Expand Down

0 comments on commit b002c0e

Please sign in to comment.