Skip to content

Commit

Permalink
feat: used rayon to parallelize child rollouts
Browse files Browse the repository at this point in the history
  • Loading branch information
PaytonWebber committed Dec 11, 2024
1 parent 9477d36 commit f338156
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 52 deletions.
54 changes: 53 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ categories = ["algorithms", "data-structures"]

[dependencies]
rand = "0.8.5"
rayon = "1.7"
10 changes: 5 additions & 5 deletions examples/tic_tac_toe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ impl State for TicTacToe {
self.current_player
}

fn step(&self, action: (usize, usize)) -> Self {
fn step(&self, action: &(usize, usize)) -> Self {
let mut new_board = self.board;
new_board[action.0][action.1][self.current_player] = 1;

// Create a new vector excluding the taken action
let mut new_legal_actions = Vec::with_capacity(self.legal_actions.len() - 1);
for &a in &self.legal_actions {
if a != action {
if a != *action {
new_legal_actions.push(a);
}
}
Expand Down Expand Up @@ -140,12 +140,12 @@ fn main() {

// Randomly select the first action
let action = game.legal_actions.choose(&mut rand::thread_rng()).unwrap();
game = game.step(*action);
game = game.step(action);

while !game.is_terminal() {
let mut mcts = Mcts::new(game.clone(), 5.0);
let action = mcts.search(100000);
game = game.step(action);
let action = mcts.search(1000);
game = game.step(&action);
game.render();
}

Expand Down
8 changes: 4 additions & 4 deletions examples/ultimate_tic_tac_toe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl State for UltimateTicTacToe {
(0, 0, 0)
}

fn step(&self, action: (usize, usize, usize)) -> Self {
fn step(&self, action: &(usize, usize, usize)) -> Self {
let mut new_board = self.board.clone();
new_board[action.0][action.1][action.2][self.player as usize] = 1;
let legal_actions =
Expand Down Expand Up @@ -208,7 +208,7 @@ impl UltimateTicTacToe {
fn mini_board_full(board: &[[[[u8; 2]; 3]; 3]; 9], board_idx: usize) -> bool {
for i in 0..3 {
for j in 0..3 {
if board[board_idx][i][j][0] == 0 && board[board_idx][i][j][0] == 0 {
if board[board_idx][i][j][0] == 0 && board[board_idx][i][j][1] == 0 {
return false;
}
}
Expand All @@ -220,12 +220,12 @@ impl UltimateTicTacToe {
fn main() {
let mut game = UltimateTicTacToe::new();
let action = game.legal_actions.choose(&mut rand::thread_rng()).unwrap();
game = game.step(*action);
game = game.step(action);

while !game.is_terminal() {
let mut mcts = Mcts::new(game.clone(), 1.4);
let action = mcts.search(1000);
game = game.step(action);
game = game.step(&action);
game.render();
}

Expand Down
128 changes: 89 additions & 39 deletions src/mcts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ use arena::Arena;
use node::Node;

use rand::seq::SliceRandom;
use rayon::prelude::*;

use std::sync::Mutex;

pub struct Mcts<S: State> {
pub arena: Arena<S>,
pub arena: Mutex<Arena<S>>,
pub root_id: usize,
c: f64,
}
Expand All @@ -18,89 +21,136 @@ impl<S: State + std::fmt::Debug + std::clone::Clone> Mcts<S> {
let mut arena: Arena<S> = Arena::new();
let root: Node<S> = Node::new(state.clone(), S::default_action(), None);
let root_id: usize = arena.add_node(root);
Mcts { arena, root_id, c }
Mcts {
arena: Mutex::new(arena),
root_id,
c,
}
}

pub fn search(&mut self, n: usize) -> S::Action {
for _ in 0..n {
let mut selected_id: usize = self.select();
let selected_node: &Node<S> = self.arena.get_node(selected_id);
if !selected_node.state.is_terminal() {
self.expand(selected_id);
let children: &Vec<usize> = &self.arena.get_node(selected_id).children;
let random_child: usize = children.choose(&mut rand::thread_rng()).unwrap().clone();
selected_id = random_child;
let selected_id = self.select();

// Lock only to check terminality
{
let arena = self.arena.lock().unwrap();
let selected_node = arena.get_node(selected_id);
if !selected_node.state.is_terminal() {
drop(arena);
self.expand(selected_id);
continue;
}
}
let reward: f64 = self.simulate(selected_id);
self.backprop(selected_id, reward);

// Simulate and backprop
let state = {
let arena = self.arena.lock().unwrap();
arena.get_node(selected_id).state.clone()
};
let reward = self.simulate_from_state(state.clone(), state.to_play());
self.backprop(selected_id, reward, 1);
}
let root_node: &Node<S> = self.arena.get_node(self.root_id);
let best_child: usize = root_node

let arena = self.arena.lock().unwrap();
let root_node = arena.get_node(self.root_id);
let best_child = root_node
.children
.iter()
.max_by(|&a, &b| {
let node_a_score = self.arena.get_node(*a).q;
let node_b_score = self.arena.get_node(*b).q;
let node_a_score = arena.get_node(*a).q;
let node_b_score = arena.get_node(*b).q;
node_a_score.partial_cmp(&node_b_score).unwrap()
})
.unwrap()
.clone();

let best_action: S::Action = self.arena.get_node(best_child).action;
best_action
arena.get_node(best_child).action.clone()
}

fn select(&mut self) -> usize {
let mut current: usize = 0;
let mut current: usize = self.root_id;
loop {
let node = &self.arena.get_node(current);
let arena = self.arena.lock().unwrap();
let node = arena.get_node(current);
if node.is_leaf() || node.state.is_terminal() {
return current;
}
let best_child = node.get_best_child(&self.arena, self.c);
let best_child = node.get_best_child(&arena, self.c);
current = best_child;
}
}

fn expand(&mut self, id: usize) {
let parent: &Node<S> = self.arena.get_node_mut(id);
let legal_actions: Vec<S::Action> = parent.state.get_legal_actions();
let parent_state: S = parent.state.clone();
for action in legal_actions {
let state = parent_state.step(action);
let new_node = Node::new(state, action, Some(id));
let new_id = self.arena.add_node(new_node);
self.arena.get_node_mut(id).children.push(new_id);
let children_info = {
let mut arena = self.arena.lock().unwrap();
let parent = arena.get_node_mut(id);

let parent_state = parent.state.clone();
let legal_actions = parent_state.get_legal_actions();

// Create children nodes
let mut children_info = Vec::new();
for action in legal_actions {
let child_state = parent_state.step(&action);
let child_node = Node::new(child_state.clone(), action.clone(), Some(id));
let child_id = arena.add_node(child_node);

children_info.push((child_id, child_state));
}
children_info
};

// add children to parent
let mut arena = self.arena.lock().unwrap();
let parent = arena.get_node_mut(id);
for (child_id, _) in &children_info {
parent.children.push(*child_id);
}
drop(arena);

// Step 2: Parallel simulations outside the lock
let results: Vec<f64> = children_info
.par_iter()
.map(|(_, state)| self.simulate_from_state(state.clone(), state.to_play()))
.collect();

// Step 3: Aggregate results
let total_reward: f64 = results.iter().sum();
let total_visits: usize = results.len();

// Step 4: Backprop once with aggregated results
if total_visits > 0 {
self.backprop(id, total_reward, total_visits);
}
}

fn simulate(&self, id: usize) -> f64 {
let node: &Node<S> = self.arena.get_node(id);
let mut state: S = node.state.clone();
/// Simulate a rollout from a given state until terminal, without locking the arena.
fn simulate_from_state(&self, mut state: S, to_play: usize) -> f64 {
while !state.is_terminal() {
let legal_actions = state.get_legal_actions();
let action = legal_actions
.choose(&mut rand::thread_rng())
.unwrap()
.clone();
state = state.step(action);
state = state.step(&action);
}
let reward: f64 = state.reward(node.state.to_play()) as f64;
reward
state.reward(to_play) as f64
}

fn backprop(&mut self, id: usize, mut reward: f64) {
let mut current: usize = id;
fn backprop(&mut self, id: usize, mut reward: f64, total_n: usize) {
let mut current = id;
loop {
let node = self.arena.get_node_mut(current);
let mut arena = self.arena.lock().unwrap();
let node = arena.get_node_mut(current);
node.reward_sum += reward;
node.n += 1;
node.n += total_n;
node.q = node.reward_sum / node.n as f64;
if let Some(parent_id) = node.parent {
current = parent_id;
} else {
break;
}
// Flip the reward for the parent
reward = -reward;
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/state.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
pub trait State {
type Action: Copy;
pub trait State: Send + Sync {
type Action: Send + Sync + Clone;
fn default_action() -> Self::Action;

fn player_has_won(&self, player: usize) -> bool;
fn is_terminal(&self) -> bool;
fn get_legal_actions(&self) -> Vec<Self::Action>;
fn to_play(&self) -> usize;
fn step(&self, action: Self::Action) -> Self;
fn step(&self, action: &Self::Action) -> Self;
fn reward(&self, to_play: usize) -> f32;
fn render(&self);
}

0 comments on commit f338156

Please sign in to comment.