diff --git a/.gitignore b/.gitignore index 180afa3..6793d60 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,5 @@ *.py[cod] common/*.pyc -main.py mnist.py output/ output_test/ diff --git a/README.md b/README.md index c8cfda8..806beeb 100644 --- a/README.md +++ b/README.md @@ -36,4 +36,50 @@ best_node = mcts.best_action(10000) If you want to apply MCTS for your own game, its state implementation should derive from `mmctspy.games.common.TwoPlayersGameState` -(lookup `mctspy.games.examples.tictactoe.TicTacToeGameState` for inspiration) \ No newline at end of file +(lookup `mctspy.games.examples.tictactoe.TicTacToeGameState` for inspiration) + +### Example Game Play +```python +import numpy as np +from mctspy.tree.nodes import TwoPlayersGameMonteCarloTreeSearchNode +from mctspy.tree.search import MonteCarloTreeSearch +from mctspy.games.examples.connect4 import Connect4GameState + +# define inital state +state = np.zeros((7, 7)) +board_state = Connect4GameState( + state=state, next_to_move=np.random.choice([-1, 1]), win=4) + +# link pieces to icons +pieces = {0: " ", 1: "X", -1: "O"} + +# print a single row of the board +def stringify(row): + return " " + " | ".join(map(lambda x: pieces[int(x)], row)) + " " + +# display the whole board +def display(board): + board = board.copy().T[::-1] + for row in board[:-1]: + print(stringify(row)) + print("-"*(len(row)*4-1)) + print(stringify(board[-1])) + print() + +display(board_state.board) +# keep playing until game terminates +while board_state.game_result is None: + # calculate best move + root = TwoPlayersGameMonteCarloTreeSearchNode(state=board_state) + mcts = MonteCarloTreeSearch(root) + best_node = mcts.best_action(total_simulation_seconds=1) + + # update board + board_state = best_node.state + # display board + display(board_state.board) + +# print result +print(pieces[board_state.game_result]) + +``` diff --git a/mctspy/games/examples/connect4.py b/mctspy/games/examples/connect4.py new file mode 100644 index 0000000..750e36a --- /dev/null +++ b/mctspy/games/examples/connect4.py @@ -0,0 +1,30 @@ +import numpy as np +from mctspy.games.examples.tictactoe import TicTacToeGameState, TicTacToeMove + +class Connect4GameState(TicTacToeGameState): + + def is_move_legal(self, move): + # check if correct player moves + if move.value != self.next_to_move: + return False + + # check if inside the board on x-axis + x_in_range = (0 <= move.x_coordinate < self.board_size) + if not x_in_range: + return False + + # check if inside the board on y-axis + y_in_range = (0 <= move.y_coordinate < self.board_size) + if not y_in_range: + return False + + # finally check if board field not occupied yet + return self.board[move.x_coordinate, move.y_coordinate] == 0 and (move.y_coordinate == 0 or self.board[move.x_coordinate, move.y_coordinate-1] != 0) + + def get_legal_actions(self): + indices = np.where(np.count_nonzero(self.board,axis=1) != self.board_size)[0] + # print(indices) + return [ + TicTacToeMove(i, np.count_nonzero(self.board[i,:]), self.next_to_move) + for i in indices + ] diff --git a/mctspy/games/examples/tictactoe.py b/mctspy/games/examples/tictactoe.py index 816da07..4d59bb8 100644 --- a/mctspy/games/examples/tictactoe.py +++ b/mctspy/games/examples/tictactoe.py @@ -21,37 +21,37 @@ class TicTacToeGameState(TwoPlayersAbstractGameState): x = 1 o = -1 - def __init__(self, state, next_to_move=1): + def __init__(self, state, next_to_move=1, win=None): if len(state.shape) != 2 or state.shape[0] != state.shape[1]: raise ValueError("Only 2D square boards allowed") self.board = state self.board_size = state.shape[0] + if win is None: + win = self.board_size + self.win = win self.next_to_move = next_to_move @property def game_result(self): # check if game is over - rowsum = np.sum(self.board, 0) - colsum = np.sum(self.board, 1) - diag_sum_tl = self.board.trace() - diag_sum_tr = self.board[::-1].trace() - - player_one_wins = any(rowsum == self.board_size) - player_one_wins += any(colsum == self.board_size) - player_one_wins += (diag_sum_tl == self.board_size) - player_one_wins += (diag_sum_tr == self.board_size) - - if player_one_wins: - return self.x - - player_two_wins = any(rowsum == -self.board_size) - player_two_wins += any(colsum == -self.board_size) - player_two_wins += (diag_sum_tl == -self.board_size) - player_two_wins += (diag_sum_tr == -self.board_size) - - if player_two_wins: - return self.o - + for i in range(self.board_size - self.win + 1): + rowsum = np.sum(self.board[i:i+self.win], 0) + colsum = np.sum(self.board[:,i:i+self.win], 1) + if rowsum.max() == self.win or colsum.max() == self.win: + return self.x + if rowsum.min() == -self.win or colsum.min() == -self.win: + return self.o + for i in range(self.board_size - self.win + 1): + for j in range(self.board_size - self.win + 1): + sub = self.board[i:i+self.win,j:j+self.win] + diag_sum_tl = sub.trace() + diag_sum_tr = sub[::-1].trace() + if diag_sum_tl == self.win or diag_sum_tr == self.win: + return self.x + if diag_sum_tl == -self.win or diag_sum_tr == -self.win: + return self.o + + # draw if np.all(self.board != 0): return 0. @@ -76,7 +76,7 @@ def is_move_legal(self, move): if not y_in_range: return False - # finally check if board field not occupied yet + # finally check if board field not occupied ye return self.board[move.x_coordinate, move.y_coordinate] == 0 def move(self, move): @@ -86,12 +86,11 @@ def move(self, move): ) new_board = np.copy(self.board) new_board[move.x_coordinate, move.y_coordinate] = move.value - if self.next_to_move == TicTacToeGameState.x: - next_to_move = TicTacToeGameState.o + if self.next_to_move == self.x: + next_to_move = self.o else: - next_to_move = TicTacToeGameState.x - - return TicTacToeGameState(new_board, next_to_move) + next_to_move = self.x + return type(self)(new_board, next_to_move, self.win) def get_legal_actions(self): indices = np.where(self.board == 0)