-
Notifications
You must be signed in to change notification settings - Fork 1
/
mcts_node.py
49 lines (40 loc) · 1.84 KB
/
mcts_node.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import math
from thoughtsculpt.thought_evaluator import ThoughtEvaluator
from thoughtsculpt.thought_generator import ThoughtGenerator
class MCTSNode:
def __init__(self, solution, parent=None):
self.solution = solution
self.parent = parent
self.children = []
self.num_visits = 0
self.total_reward = 0.0
def is_terminal(self):
# Implement logic to determine if the node is a terminal node
pass
def is_fully_expanded(self):
# Implement logic to determine if the node is fully expanded
pass
def expand(self, api_key, task_description):
thought_generator = ThoughtGenerator(api_key)
feedback_text, _ = ThoughtEvaluator(api_key).evaluate(self.solution, task_description)
candidate_solution = thought_generator.generate(self.solution, feedback_text, task_description)[0]
child = MCTSNode(candidate_solution, self)
self.children.append(child)
return child
def select_child(self):
ucb1_scores = [child.ucb1_score(self.num_visits) for child in self.children]
return self.children[ucb1_scores.index(max(ucb1_scores))]
def ucb1_score(self, parent_visits):
exploration_factor = math.sqrt(2 * math.log(parent_visits) / self.num_visits)
return (self.total_reward / self.num_visits) + exploration_factor
def backpropagate(self, reward):
self.num_visits += 1
self.total_reward += reward
if self.parent:
self.parent.backpropagate(reward)
def best_child(self):
return max(self.children, key=lambda child: child.total_reward / child.num_visits)
def evaluate(self, api_key, task_description):
thought_evaluator = ThoughtEvaluator(api_key)
_, feedback_score = thought_evaluator.evaluate(self.solution, task_description)
return feedback_score