-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.spec.py
61 lines (46 loc) · 1.61 KB
/
model.spec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#
# Tests that Connect4Model can be trained on data extracted from MCTS self-play games
#
import unittest
import random
from model import Connect4Model
from agent import Agent
from mcts import MCTS
from abstractgame import AbstractGame
from connectgame import ConnectGame
from compare import compare
from connect4adapter import connect4adapter_value
import numpy as np
def agent(hardness):
    """Return an Agent driven by MCTS with `hardness` search iterations.

    The Agent is given a factory that, for each `agi` it passes in, builds
    ``MCTS(agi, iteration_number=hardness)``; the Agent is created with
    ``temperature=1``.
    """
    return Agent(lambda agi: MCTS(agi, iteration_number=hardness), temperature=1)
def game_creator():
    """Build and return a new AbstractGame wrapping a fresh ConnectGame.

    Handed to compare() as the game factory; every call constructs new
    objects rather than reusing a shared game.
    """
    connect_game = ConnectGame()
    return AbstractGame(connect_game)
class TestModelTrains(unittest.TestCase):
    """Checks that Connect4Model.train runs on MCTS self-play game records.

    Each test plays games between two identical MCTS agents via compare(),
    converts the resulting record with connect4adapter_value(), and passes
    the (X, Y) data to Connect4Model.train.
    """

    # Budget for the quick "train twice on one batch" check.
    OVERFIT_GAMES = 5
    OVERFIT_MCTS_ITERATIONS = 10
    # Budget for the long fitting run (slow: 50 rounds x 20 games).
    FIT_ROUNDS = 50
    FIT_GAMES_PER_ROUND = 20
    FIT_MCTS_ITERATIONS = 150

    def _self_play_data(self, games, mcts_iterations):
        """Play `games` games between two MCTS agents; return (X, Y) from the record."""
        _, _, record = compare(
            agent(mcts_iterations),
            agent(mcts_iterations),
            games,
            game_creator, verbose=False)
        return connect4adapter_value(record)

    def test_model_overfits(self):
        """Training twice on the same batch should lower the returned loss."""
        m = Connect4Model()
        X, Y = self._self_play_data(self.OVERFIT_GAMES,
                                    self.OVERFIT_MCTS_ITERATIONS)
        # Fresh arrays per call, as before, in case train() mutates its inputs.
        loss_one = m.train(np.array(X), np.array(Y))
        loss_two = m.train(np.array(X), np.array(Y))
        # assertGreater shows both losses on failure, unlike assertTrue(a > b).
        self.assertGreater(loss_one, loss_two)

    def test_model_fits(self):
        """Train on fresh self-play data for several rounds; losses must stay finite."""
        m = Connect4Model()
        losses = []
        for _ in range(self.FIT_ROUNDS):
            X, Y = self._self_play_data(self.FIT_GAMES_PER_ROUND,
                                        self.FIT_MCTS_ITERATIONS)
            loss = m.train(np.array(X), np.array(Y))
            # Function-call form prints identically on Python 2 and is valid on Python 3.
            print(loss)
            losses.append(loss)
        # NOTE(review): assumes train() returns a numeric (scalar) loss, as the
        # comparison in test_model_overfits suggests — confirm. Guards against
        # NaN/inf divergence; the original test asserted nothing.
        self.assertTrue(np.all(np.isfinite(losses)))
if __name__ == '__main__':
    # Running this file directly executes the TestCase classes defined in it.
    unittest.main(module='__main__')