Skip to content

Commit

Permalink
Fixed a couple of issues (#14)
Browse files Browse the repository at this point in the history
* generalised tic-tac-toe

* added example to README (#10)

* generalized tic-tac-toe game

* added connect 4 game (int#11)

* updated README for connect 4

* added main.py for example

* added connect 4 game (#11)

* updated README for connect 4

* added main.py for example
  • Loading branch information
George-Ogden authored May 23, 2022
1 parent 58771d1 commit 78ca27d
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 30 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
*.py[cod]
common/*.pyc
main.py
mnist.py
output/
output_test/
Expand Down
48 changes: 47 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,50 @@ best_node = mcts.best_action(10000)
If you want to apply MCTS for your own game, its state implementation should derive from
`mmctspy.games.common.TwoPlayersGameState`

(lookup `mctspy.games.examples.tictactoe.TicTacToeGameState` for inspiration)
(lookup `mctspy.games.examples.tictactoe.TicTacToeGameState` for inspiration)

### Example Game Play
```python
import numpy as np
from mctspy.tree.nodes import TwoPlayersGameMonteCarloTreeSearchNode
from mctspy.tree.search import MonteCarloTreeSearch
from mctspy.games.examples.connect4 import Connect4GameState

# define inital state
state = np.zeros((7, 7))
board_state = Connect4GameState(
state=state, next_to_move=np.random.choice([-1, 1]), win=4)

# link pieces to icons
pieces = {0: " ", 1: "X", -1: "O"}

# print a single row of the board
def stringify(row):
return " " + " | ".join(map(lambda x: pieces[int(x)], row)) + " "

# display the whole board
def display(board):
board = board.copy().T[::-1]
for row in board[:-1]:
print(stringify(row))
print("-"*(len(row)*4-1))
print(stringify(board[-1]))
print()

display(board_state.board)
# keep playing until game terminates
while board_state.game_result is None:
# calculate best move
root = TwoPlayersGameMonteCarloTreeSearchNode(state=board_state)
mcts = MonteCarloTreeSearch(root)
best_node = mcts.best_action(total_simulation_seconds=1)

# update board
board_state = best_node.state
# display board
display(board_state.board)

# print result
print(pieces[board_state.game_result])

```
30 changes: 30 additions & 0 deletions mctspy/games/examples/connect4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np
from mctspy.games.examples.tictactoe import TicTacToeGameState, TicTacToeMove

class Connect4GameState(TicTacToeGameState):

def is_move_legal(self, move):
# check if correct player moves
if move.value != self.next_to_move:
return False

# check if inside the board on x-axis
x_in_range = (0 <= move.x_coordinate < self.board_size)
if not x_in_range:
return False

# check if inside the board on y-axis
y_in_range = (0 <= move.y_coordinate < self.board_size)
if not y_in_range:
return False

# finally check if board field not occupied yet
return self.board[move.x_coordinate, move.y_coordinate] == 0 and (move.y_coordinate == 0 or self.board[move.x_coordinate, move.y_coordinate-1] != 0)

def get_legal_actions(self):
indices = np.where(np.count_nonzero(self.board,axis=1) != self.board_size)[0]
# print(indices)
return [
TicTacToeMove(i, np.count_nonzero(self.board[i,:]), self.next_to_move)
for i in indices
]
55 changes: 27 additions & 28 deletions mctspy/games/examples/tictactoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,37 +21,37 @@ class TicTacToeGameState(TwoPlayersAbstractGameState):
x = 1
o = -1

def __init__(self, state, next_to_move=1):
def __init__(self, state, next_to_move=1, win=None):
if len(state.shape) != 2 or state.shape[0] != state.shape[1]:
raise ValueError("Only 2D square boards allowed")
self.board = state
self.board_size = state.shape[0]
if win is None:
win = self.board_size
self.win = win
self.next_to_move = next_to_move

@property
def game_result(self):
# check if game is over
rowsum = np.sum(self.board, 0)
colsum = np.sum(self.board, 1)
diag_sum_tl = self.board.trace()
diag_sum_tr = self.board[::-1].trace()

player_one_wins = any(rowsum == self.board_size)
player_one_wins += any(colsum == self.board_size)
player_one_wins += (diag_sum_tl == self.board_size)
player_one_wins += (diag_sum_tr == self.board_size)

if player_one_wins:
return self.x

player_two_wins = any(rowsum == -self.board_size)
player_two_wins += any(colsum == -self.board_size)
player_two_wins += (diag_sum_tl == -self.board_size)
player_two_wins += (diag_sum_tr == -self.board_size)

if player_two_wins:
return self.o

for i in range(self.board_size - self.win + 1):
rowsum = np.sum(self.board[i:i+self.win], 0)
colsum = np.sum(self.board[:,i:i+self.win], 1)
if rowsum.max() == self.win or colsum.max() == self.win:
return self.x
if rowsum.min() == -self.win or colsum.min() == -self.win:
return self.o
for i in range(self.board_size - self.win + 1):
for j in range(self.board_size - self.win + 1):
sub = self.board[i:i+self.win,j:j+self.win]
diag_sum_tl = sub.trace()
diag_sum_tr = sub[::-1].trace()
if diag_sum_tl == self.win or diag_sum_tr == self.win:
return self.x
if diag_sum_tl == -self.win or diag_sum_tr == -self.win:
return self.o

# draw
if np.all(self.board != 0):
return 0.

Expand All @@ -76,7 +76,7 @@ def is_move_legal(self, move):
if not y_in_range:
return False

# finally check if board field not occupied yet
# finally check if board field not occupied ye
return self.board[move.x_coordinate, move.y_coordinate] == 0

def move(self, move):
Expand All @@ -86,12 +86,11 @@ def move(self, move):
)
new_board = np.copy(self.board)
new_board[move.x_coordinate, move.y_coordinate] = move.value
if self.next_to_move == TicTacToeGameState.x:
next_to_move = TicTacToeGameState.o
if self.next_to_move == self.x:
next_to_move = self.o
else:
next_to_move = TicTacToeGameState.x

return TicTacToeGameState(new_board, next_to_move)
next_to_move = self.x
return type(self)(new_board, next_to_move, self.win)

def get_legal_actions(self):
indices = np.where(self.board == 0)
Expand Down

0 comments on commit 78ca27d

Please sign in to comment.