-
Notifications
You must be signed in to change notification settings - Fork 0
/
mcts_node.py
160 lines (121 loc) · 4.65 KB
/
mcts_node.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import random
import time
from connect4 import Board
import math
class MonteCarloNode:
    """A node in a Monte Carlo search tree for a two-player game.

    Rewards accumulated here come from ``state.payoff()`` of randomly simulated
    terminal states. Selection treats them from player 0's point of view: when
    ``state.actor() == 0`` the largest value is preferred, otherwise the smallest.
    """

    def __init__(self, state: "Board", parent=None, parent_action=None):
        """Create a search node for ``state``.

        Args:
            state: Game position represented by this node.
            parent: Parent node, or ``None`` for the root.
            parent_action: Action applied to ``parent.state`` to reach ``state``.
        """
        self.state = state
        self.parent = parent
        self.parent_action = parent_action
        self.total_visits = 0
        self.total_rewards = 0
        self.children = []
        # Actions not yet expanded into children. Copied so that expand()'s pop()
        # can never mutate a list owned by the state object itself.
        if self.state.is_terminal():
            self.missing_child_actions = []
        else:
            self.missing_child_actions = list(self.state.get_actions())

    def is_fully_expanded(self):
        """Return True when every available action already has a child node."""
        return len(self.missing_child_actions) == 0

    def get_average_reward(self):
        """Return the mean simulated reward, or 0 for a node never visited."""
        if self.total_visits == 0:
            return 0
        return self.total_rewards / self.total_visits

    def get_best_average_child(self):
        """Return the child with the best average reward for the player to move.

        Player 0 takes the maximum average and the other player takes the
        minimum.

        Raises:
            ValueError: if this node has no children yet.
        """
        if not self.children:
            raise ValueError("node has no children to choose from")
        if self.state.actor() == 0:
            return max(self.children, key=lambda child: child.get_average_reward())
        return min(self.children, key=lambda child: child.get_average_reward())

    def get_ucb_value(self, parent_total_visits, parent_actor):
        """Return this node's UCB1 value as seen from its parent.

        The exploration bonus ``sqrt(2 * ln(N) / n)`` is added when the parent's
        actor is player 0 (who maximizes) and subtracted otherwise (minimizer).
        An unvisited node scores ``+inf`` (or ``-inf`` for the minimizer) so it is
        always tried first, instead of dividing by zero / taking ``log(0)``.

        Args:
            parent_total_visits: Visit count ``N`` of the parent node.
            parent_actor: ``actor()`` of the parent's state.
        """
        if self.total_visits == 0:
            return math.inf if parent_actor == 0 else -math.inf
        bonus = math.sqrt(2 * math.log(parent_total_visits) / self.total_visits)
        if parent_actor == 0:
            return self.get_average_reward() + bonus
        return self.get_average_reward() - bonus

    def get_best_ucb_child(self):
        """Return the child with the best UCB value for the player to move."""
        actor = self.state.actor()
        pick = max if actor == 0 else min
        return pick(
            self.children,
            key=lambda child: child.get_ucb_value(self.total_visits, actor),
        )

    def expand(self):
        """Create, attach and return a child for one not-yet-tried action.

        The action is popped from the end of ``missing_child_actions``; calling
        this on a fully expanded node raises IndexError.
        """
        action = self.missing_child_actions.pop()
        next_state = self.state.successor(action)
        child_node = MonteCarloNode(next_state, parent=self, parent_action=action)
        self.children.append(child_node)
        return child_node

    def find_leaf_node(self):
        """Select the node to simulate from, descending by UCB.

        Walks down through fully expanded nodes using get_best_ucb_child(). It
        returns a newly expanded child at the first node with untried actions,
        or the terminal node reached if no such node is met.
        """
        current_node = self
        while not current_node.state.is_terminal():
            if not current_node.is_fully_expanded():
                return current_node.expand()
            current_node = current_node.get_best_ucb_child()
        return current_node

    def simulate(self):
        """Play uniformly random actions from this state to the end of the game.

        Returns:
            ``payoff()`` of the terminal state reached.
        """
        state = self.state
        while not state.is_terminal():
            state = state.successor(random.choice(state.get_actions()))
        return state.payoff()

    def update_rewards(self, reward):
        """Add ``reward`` and one visit to this node and every ancestor."""
        node = self
        while node is not None:
            node.total_rewards += reward
            node.total_visits += 1
            node = node.parent

    def __str__(self) -> str:
        """Return a multi-line human-readable summary of the node's statistics."""
        ucb_value = "N/A"
        if self.parent:
            ucb_value = self.get_ucb_value(self.parent.total_visits, self.parent.state.actor())
        string_form = f"""
Node: {self.state}
Total Visits: {self.total_visits}
Total Rewards: {self.total_rewards}
Average Reward: {self.get_average_reward()}
UCB Value: {ucb_value}
Parent Action: {self.parent_action}
Actor: {self.state.actor()}
Missing Child Actions: {self.missing_child_actions}
Children: {len(self.children)}
"""
        return string_form
def mcts_policy(time_duration):
    """Build a move-selection policy backed by time-limited Monte Carlo tree search.

    Args:
        time_duration: Seconds of search to spend on each call of the returned
            policy. At least one search iteration is always run, even if this
            is zero or negative.

    Returns:
        A function ``fxn(initial_position)``. It searches from
        ``initial_position`` and returns the action leading to the root child
        with the best average reward for the player to move.
        ``fxn`` raises ValueError if ``initial_position`` is terminal.
    """
    def fxn(initial_position: "Board"):
        if initial_position.is_terminal():
            raise ValueError("cannot choose an action from a terminal position")
        root = MonteCarloNode(initial_position, None)
        # perf_counter is monotonic, so wall-clock adjustments cannot stretch or
        # cut short the search budget the way time.time() could.
        deadline = time.perf_counter() + time_duration
        # Do-while: guarantees the root has at least one child to choose from.
        while True:
            # Selection + expansion: descend by UCB and expand one new node.
            node = root.find_leaf_node()
            # Simulation: random playout to a terminal state.
            reward = node.simulate()
            # Backpropagation: credit the node and all of its ancestors.
            node.update_rewards(reward)
            if time.perf_counter() >= deadline:
                break
        return root.get_best_average_child().parent_action
    return fxn