
Commit

Merge remote-tracking branch 'origin/main'

kngwyu committed Oct 2, 2024
2 parents e863053 + f0a7b6f commit ed1624f
Showing 1 changed file with 129 additions and 84 deletions.
src/emevo/analysis/tree.py: 213 changes (129 additions, 84 deletions)
@@ -5,7 +5,7 @@
import dataclasses
import functools
from collections.abc import Iterable, Sequence
from typing import Any
from typing import Any, Callable
from weakref import ReferenceType
from weakref import ref as make_weakref

@@ -17,7 +17,6 @@


datafield = functools.partial(dataclasses.field, compare=False, hash=False, repr=False)
_ROOT_INDEX = -1


@functools.total_ordering
@@ -116,7 +115,43 @@ def __lt__(self, other: Edge) -> bool:
class SplitNode:
size: int
reward_mean: dict[str, float] | None = None
children: list[int] = dataclasses.field(default_factory=list)
children: set[int] = dataclasses.field(default_factory=set)


@functools.cache
def compute_reward_mean(
node: Node,
is_root: bool = False,
skipped_edges: frozenset[tuple[int, int]] | None = None,
reward_keys: tuple[str, ...] = (),
) -> tuple[int, dict[str, float]]:
if is_root:
size_list = [0]
reward_mean_lists = {key: [0.0] for key in reward_keys}
else:
if reward_keys[0] not in node.info:
return 0, {key: 0.0 for key in reward_keys}
size_list = [1]
reward_mean_lists = {key: [node.info[key]] for key in reward_keys}

for child in node.children:
if skipped_edges is not None and (node.index, child.index) in skipped_edges:
continue
n_children, reward_mean = compute_reward_mean(
child,
skipped_edges=skipped_edges,
reward_keys=reward_keys,
)
size_list.append(n_children)
for key, rmean in reward_mean.items():
reward_mean_lists[key].append(rmean)

total_size = np.sum(size_list)
rmean_dict = {}
for key, rmean in reward_mean_lists.items():
rsum = np.sum([nc * rm for nc, rm in zip(size_list, rmean)])
rmean_dict[key] = rsum / total_size
return total_size, rmean_dict


@dataclasses.dataclass
@@ -126,10 +161,11 @@ class Tree:

@staticmethod
def from_iter(
iterator: Iterable[tuple[int, int] | tuple[int, int, dict]], root_idx: int = 0
iterator: Iterable[tuple[int, int] | tuple[int, int, dict]],
root_idx: int = 0,
) -> Tree:
nodes = {}
root = Node(index=_ROOT_INDEX, is_root=True)
root = Node(index=root_idx, is_root=True)

for item in iterator:
if len(item) == 2:
@@ -160,7 +196,7 @@ def from_iter(
def from_table(
table: Table,
initial_population: int | None = None,
root_idx: int = -1,
root_idx: int = 0,
) -> Tree:
birth_steps = {}

@@ -217,7 +253,7 @@ def split(
reward_keys: list[str] | None = None,
) -> dict[int, SplitNode]:
if method == "greedy":
split_nodes = self._split_greedy(min_group_size)
split_nodes = self._split_greedy(min_group_size, reward_keys)
elif method == "reward":
split_nodes = self._split_reward_mean(min_group_size, n_trial, reward_keys)
else:
@@ -238,38 +274,61 @@ def colorize_impl(node: Node, color: int) -> None:
colorize_impl(self.nodes[node_idx], i)
return categ

def _split_greedy(self, min_group_size) -> dict[int, SplitNode]:
def _split_greedy(
self,
min_group_size: int,
reward_keys: list[str] | None,
) -> dict[int, SplitNode]:
split_nodes = {}
split_edges = set()

def split(node: Node, threshold: int) -> int:
size = 0
size = 1
for child in node.children:
# Number of children that are not splitted
n_existing_children = split(child, threshold)
size += n_existing_children

if size >= threshold:
parent = node.parent
split_nodes[node.index] = SplitNode(size)
if parent is not None:
split_edges.add((parent.index, node.index))
return 0
else:
return size

def find_children(node: Node) -> list[int]:
children = []
for child in node.children:
children += find_children(child)
size = split(self.root, min_group_size)
if size < min_group_size:
split_nodes[self.root.index] = SplitNode(size)

if node in split_nodes:
split_nodes[node.index].children = children
return list[node.index]
else:
return children
for node_index, split_node in split_nodes.items():
if node_index == self.root.index:
continue
# Find Parent
ancestor = self.nodes[node_index].parent
while ancestor is not None:
if ancestor.index in split_nodes:
split_nodes[ancestor.index].children.add(node_index)
break
ancestor = ancestor.parent

for root in self.root.children:
size = split(root, min_group_size)
if size >= min_group_size:
split_nodes[root.index] = SplitNode(size)
find_children(root)
if reward_keys is not None:
reward_keys_t = tuple(reward_keys)
frozen_split_edges = frozenset(split_edges)
for node_index, split_node in split_nodes.items():
node = (
self.root
if node_index == self.root.index
else self.nodes[node_index]
)
size, reward = compute_reward_mean(
node,
skipped_edges=frozen_split_edges,
reward_keys=reward_keys_t,
)
split_node.size = size
split_node.reward_mean = reward

return split_nodes

@@ -281,99 +340,85 @@ def _split_reward_mean(
) -> dict[int, SplitNode]:
split_nodes = {}
split_edges = set()
reward_keys_t = tuple(reward_keys)

@functools.cache
def compute_reward_mean(
node: Node,
n_split: int = 0,
is_root: bool = False,
) -> tuple[int, dict[str, float]]:
if is_root:
size_list = [0]
reward_mean_lists = {key: [0.0] for key in reward_keys}
else:
if reward_keys[0] not in node.info:
return 0, {key: 0.0 for key in reward_keys}
size_list = [1]
reward_mean_lists = {key: [node.info[key]] for key in reward_keys}

for child in node.children:
if (node.index, child.index) in split_edges:
continue
n_children, reward_mean = compute_reward_mean(child, n_split=n_split)
size_list.append(n_children)
for key, rmean in reward_mean.items():
reward_mean_lists[key].append(rmean)

total_size = np.sum(size_list)
rmean_dict = {}
for key, rmean in reward_mean_lists.items():
rsum = np.sum([nc * rm for nc, rm in zip(size_list, rmean)])
rmean_dict[key] = rsum / total_size
return total_size, rmean_dict

def find_maxdiff_edge(n_split: int, min_group_size: int) -> tuple[float, Edge]:
def find_maxdiff_edge(
frozen_split_edges: frozenset[tuple[int, int]]
) -> tuple[float, Edge]:
max_effect = 0.0
max_effect_edge = None
for edge in self.all_edges():
if (edge.parent.index, edge.child.index) in split_edges:
continue
parent_size, parent_reward = compute_reward_mean(
edge.parent,
n_split=n_split,
skipped_edges=frozen_split_edges,
reward_keys=reward_keys_t,
)
child_size, child_reward = compute_reward_mean(
edge.child,
n_split=n_split,
skipped_edges=frozen_split_edges,
reward_keys=reward_keys_t,
)
if (
child_size < min_group_size
or (parent_size - child_size) < min_group_size
):
continue
assert parent_size > child_size, (parent_size, child_size)
split_size = parent_size - child_size
total_diff = 0.0
for key in reward_keys:
parent_rew_total = parent_reward[key] * parent_size
child_rew_total = child_reward[key] * child_size
split_rew = (parent_rew_total - child_rew_total) / split_size
total_diff += (child_reward[key] - split_rew) ** 2
effect = total_diff**0.5
if effect > max_effect:
max_effect = effect
max_effect_edge = edge
assert parent_size > child_size, (parent_size, child_size)
split_size = parent_size - child_size
total_diff = 0.0
for key in reward_keys:
parent_rew_total = parent_reward[key] * parent_size
child_rew_total = child_reward[key] * child_size
split_rew = (parent_rew_total - child_rew_total) / split_size
total_diff += (child_reward[key] - split_rew) ** 2
effect = total_diff**0.5
if effect > max_effect:
max_effect = effect
max_effect_edge = edge
assert max_effect_edge is not None, "Couldn't find maxdiff_edge anymore"
return max_effect, max_effect_edge

for i in range(n_trial):
maxe, edge = find_maxdiff_edge(i, min_group_size)
parent_size, parent_reward = compute_reward_mean(edge.parent, n_split=i)
child_size, child_reward = compute_reward_mean(edge.child, n_split=i)
size_new = parent_size - child_size
rew_new = {}
frozen_split_edges = frozenset(split_edges)
maxe, edge = find_maxdiff_edge(frozen_split_edges)
parent_size, parent_reward = compute_reward_mean(
edge.parent,
skipped_edges=frozen_split_edges,
reward_keys=reward_keys_t,
)
child_size, child_reward = compute_reward_mean(
edge.child,
skipped_edges=frozen_split_edges,
reward_keys=reward_keys_t,
)
split_size = parent_size - child_size
assert split_size > 0, (parent_size, child_size, edge)
split_rew = {}
for key in parent_reward:
rew_new[key] = (
parent_reward[key] * parent_size - child_reward[key] * child_size
) / size_new
parent_rew_total = parent_reward[key] * parent_size
child_rew_total = child_reward[key] * child_size
split_rew[key] = (parent_rew_total - child_rew_total) / split_size
# Make nodes
if edge.parent.index in split_nodes:
# Add child
split_nodes[edge.parent.index].size = size_new
split_nodes[edge.parent.index].reward_mean = rew_new
split_nodes[edge.parent.index].children.append(edge.child.index)
split_nodes[edge.parent.index].size = split_size
split_nodes[edge.parent.index].reward_mean = split_rew
split_nodes[edge.parent.index].children.add(edge.child.index)
else:
split_nodes[edge.parent.index] = SplitNode(
size_new,
rew_new,
children=[edge.child.index],
split_size,
split_rew,
children=set([edge.child.index]),
)
# Find Parent
ancestor = edge.parent.parent
while ancestor is not None:
if ancestor.index in split_nodes:
if edge.parent.index not in split_nodes[ancestor.index].children:
split_nodes[ancestor.index].children.append(edge.parent.index)
split_nodes[ancestor.index].size -= size_new
split_nodes[ancestor.index].children.add(edge.parent.index)
split_nodes[ancestor.index].size -= split_size
break
ancestor = ancestor.parent
split_nodes[edge.child.index] = SplitNode(child_size, child_reward)
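Note: the following is an illustrative, self-contained sketch and is not part of this commit. It mirrors the size-weighted averaging that the new module-level compute_reward_mean helper performs over a subtree, but uses plain dicts in place of Node objects; the tree layout, node indices, and reward values are made up for the example.

# Standalone sketch of the subtree reward-mean aggregation shown above.
# Each entry maps a node index to (reward, [child indices]); edges listed in
# skipped_edges are pruned, mirroring the skipped_edges argument in the diff.
def subtree_reward_mean(tree, index, skipped_edges=frozenset()):
    reward, children = tree[index]
    sizes, means = [1], [reward]
    for child in children:
        if (index, child) in skipped_edges:
            continue  # do not descend into a subtree cut off by a split edge
        n, m = subtree_reward_mean(tree, child, skipped_edges)
        sizes.append(n)
        means.append(m)
    total = sum(sizes)
    # Size-weighted mean of the per-subtree means, as in compute_reward_mean
    return total, sum(n * m for n, m in zip(sizes, means)) / total

# Hypothetical tree: 0 -> {1, 2}, 1 -> {3}, with one reward value per node
toy = {0: (1.0, [1, 2]), 1: (0.5, [3]), 2: (2.0, []), 3: (1.5, [])}
print(subtree_reward_mean(toy, 0))                       # (4, 1.25)
print(subtree_reward_mean(toy, 0, frozenset({(0, 2)})))  # (3, 1.0)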
