Commit 4df319c (0 parents): 110 changed files with 2,436 additions and 0 deletions.
@@ -0,0 +1,8 @@
This is a modular architecture for model-based reinforcement learning using search.

The components are kept separate, which makes it easy to create new agents and
to extend the existing components.

Run test.py for an example: it will ask you to choose from the different components
and then run.
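The interactive interface of test.py is not shown in this commit, so as a rough illustration of the environment side of the modular interface only, a minimal sketch might look as follows. The import path environments/cartpole.py is an assumption (the filename is not visible in the diff); the agent and search components are omitted.

# Minimal sketch: drive one environment component directly through the shared
# Environment interface. The module path below is hypothetical.
import numpy as np
from environments.cartpole import CartPole  # hypothetical filename

env = CartPole(max_steps=200)
obs = env.reset()
done, total_reward = False, 0.0
while not done:
    action = int(np.random.choice(env.get_legal_actions()))  # random legal action
    obs, reward, done, info = env.step(action)
    total_reward += reward
env.close()
print(env, "episode return:", total_reward)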
@@ -0,0 +1,52 @@
from environments.environment import Environment
import gym
from gym.envs.classic_control.cartpole import CartPoleEnv


'''
Adapted from https://github.com/werner-duvaud/muzero-general
'''

class CartPole(Environment):
    def __init__(self, max_steps=1000):
        self.environment = gym.make("CartPole-v1")
        self.max_steps = max_steps
        self.done = True  # reset() must be called before step()

    def step(self, action):
        assert not self.done, "cannot execute steps after the episode has finished"
        assert self.steps_taken < self.max_steps
        self.steps_taken += 1
        obs, reward, self.done, info = self.environment.step(action)
        if self.steps_taken == self.max_steps:
            self.done = True  # truncate the episode at max_steps
        self.current_observation = obs
        return obs, reward, self.done, info

    def reset(self):
        self.done = False
        self.steps_taken = 0
        self.current_observation = self.environment.reset()
        return self.current_observation

    def close(self):
        self.environment.close()

    def render(self):
        self.environment.render()

    def get_action_size(self):
        return 2

    def get_input_shape(self):
        return (4,)

    def get_legal_actions(self):
        """ In CartPole, the two actions are always legal. """
        if not self.done:
            return [0, 1]
        else:
            return []

    def __str__(self):
        return "CartPole-v1"
@@ -0,0 +1,51 @@
from typing import List, Tuple, Dict
import numpy as np
from abc import abstractmethod


"""
Environments work with numpy arrays, so don't forget to convert them to
torch tensors when appropriate.
"""
class Environment:

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
        """ Return next_observation, reward, done, info. """
        raise NotImplementedError

    def reset(self) -> np.ndarray:
        """ Return the initial observation. """
        raise NotImplementedError

    def close(self) -> None:
        return None

    def render(self) -> None:
        raise NotImplementedError

    def get_action_size(self) -> int:
        raise NotImplementedError

    def get_input_shape(self) -> Tuple[int, ...]:
        raise NotImplementedError

    def get_num_of_players(self) -> int:
        """ Default for single-player environments. """
        return 1

    def get_legal_actions(self) -> List[int]:
        """ Return an empty list when the environment has reached the end, for consistency. """
        raise NotImplementedError

    def get_action_mask(self) -> np.ndarray:
        """ Binary vector of length get_action_size() with 1 at every legal action. """
        legal_actions = self.get_legal_actions()
        mask = np.zeros(self.get_action_size())
        mask[legal_actions] = 1
        assert (np.where(mask == 1)[0] == legal_actions).all()
        return mask

    def get_current_player(self) -> int:
        """ Default for single-player environments.
        Return a player even when the environment has reached the end, for consistency. """
        return 0
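The module docstring above notes that environments return numpy arrays and that the caller is responsible for converting them to torch tensors. A minimal sketch of that conversion, combined with get_action_mask to mask out illegal actions, might look as follows. It assumes PyTorch is installed, and the policy network (one logit per action) is hypothetical.

import numpy as np
import torch

def masked_policy(network, env):
    # network: hypothetical torch.nn.Module producing one logit per action
    obs = env.reset()                                     # numpy array from the Environment API
    obs_tensor = torch.from_numpy(np.asarray(obs)).float().unsqueeze(0)  # add a batch dimension
    logits = network(obs_tensor).squeeze(0)               # shape: (env.get_action_size(),)
    mask = torch.from_numpy(env.get_action_mask()).bool()
    logits = logits.masked_fill(~mask, float("-inf"))     # illegal actions get zero probability
    return torch.softmax(logits, dim=-1)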
@@ -0,0 +1,71 @@
import gym
from gym.core import Env
try:
    import gym_minigrid
except ModuleNotFoundError:
    raise ModuleNotFoundError('Please run "pip install gym_minigrid"')
import numpy as np
from environments.environment import Environment
import random


'''
Original gym environment: https://github.com/maximecb/gym-minigrid
Set agent_start_pos to None for the start position to be random every time you reset.
'''


class Minigrid(Environment):
    def __init__(self, N=6, reward_scaling=1, max_steps=None, agent_start_pos=(1, 1), seed=None):
        self.reward_scaling = reward_scaling
        self.max_steps = max_steps
        self.environment = gym_minigrid.envs.empty.EmptyEnv(size=N + 2, agent_start_pos=agent_start_pos)
        self.environment = gym_minigrid.wrappers.ImgObsWrapper(self.environment)
        if seed is not None:
            self.environment.seed(seed)
        self.done = True  # reset() must be called before step()

    def step(self, action):
        assert not self.done, "cannot execute steps after the episode has finished"
        if self.max_steps is not None:
            assert self.steps_taken < self.max_steps
        assert action in [0, 1, 2]
        self.steps_taken += 1
        obs, reward, self.done, info = self.environment.step(action)
        if self.steps_taken == self.max_steps:
            self.done = True  # truncate the episode at max_steps
        if reward > 0:
            # MiniGrid scales the reward by how many steps were taken before,
            # which goes against the Markov property, so collapse it to 1.
            reward = 1
        return obs, self.reward_scaling * reward, self.done, info

    def reset(self):
        self.done = False
        self.steps_taken = 0
        return np.array(self.environment.reset())

    def close(self):
        self.environment.close()

    def render(self):
        return self.environment.render()

    def get_action_size(self):
        return 3

    def get_input_shape(self):
        return (7, 7, 3)

    def get_num_of_players(self):
        return 1

    def get_legal_actions(self):
        if self.done:
            return []
        else:
            return [0, 1, 2]

    def __str__(self):
        return "MiniGrid"
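As the docstring above suggests, passing agent_start_pos=None gives a random start position on every reset. A small usage sketch, assuming gym_minigrid is installed and that this file lives under environments/ as, say, environments/minigrid.py (the actual filename is not shown in the diff):

# Hypothetical import path; the actual filename is not shown in this commit.
from environments.minigrid import Minigrid

env = Minigrid(N=6, reward_scaling=10, max_steps=50, agent_start_pos=None)
obs = env.reset()               # (7, 7, 3) image observation, random agent start
done = False
while not done:
    obs, reward, done, info = env.step(2)   # action 2 is "move forward" in MiniGrid
env.close()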
@@ -0,0 +1,191 @@
import numpy as np
from copy import deepcopy
from environments.environment import Environment


""" When playing against the built-in expert (self_play=False), the expert has
3 difficulty levels:
0 - only random plays
1 - defends when obvious and attacks randomly
2 - defends and attacks when obvious, random otherwise
"""

class TicTacToe(Environment):
    def __init__(self, self_play=True, expert_start=False, expert_difficulty=2):
        self.self_play = self_play
        self.expert_start = expert_start
        self.board = None
        assert expert_difficulty in [0, 1, 2]
        self.expert_difficulty = expert_difficulty

    def step(self, action):
        if self.board is None: raise ValueError("Call reset first")
        assert not self.done, "cannot execute steps after the game has finished"
        board, reward, done, _ = self._step(action)
        if not self.self_play and not done:
            action = self._expert_action()
            board, reward, done, _ = self._step(action)
            reward = -1 * reward  # expert's reward seen from the other player's perspective
        return board, reward, done, {}

    def reset(self):
        self.done = False
        self.board = np.zeros((2, 3, 3), dtype="int32")
        self.current_player = 0
        if not self.self_play:
            if self.expert_start:
                action = self._expert_action()
                self._step(action)
            self.expert_start = not self.expert_start  # alternate who starts
        return deepcopy(self.board)

    def render(self):
        if self.board is None: raise ValueError("Call reset first")
        print(self.board[0] - self.board[1])

    def get_action_size(self):
        return 9

    def get_input_shape(self):
        return (2, 3, 3)

    def get_num_of_players(self):
        if self.self_play:
            return 2
        else:
            return 1

    def get_legal_actions(self):
        if self.board is None: raise ValueError("Call reset first")
        if self.done: return []
        legal = []
        for action in range(9):
            row, col = self._action_to_pos(action)
            if self.board[0][row, col] == 0 and self.board[1][row, col] == 0:
                legal.append(action)
        return deepcopy(legal)

    def get_current_player(self) -> int:
        if self.board is None: raise ValueError("Call reset first")
        if self.self_play is False:
            return 0
        else:
            assert self.current_player in [0, 1]
            return self.current_player

    def _step(self, action):
        if self.done == True:
            raise ValueError("Game is over")
        row, col = self._action_to_pos(action)
        if self.board[0][row, col] != 0 or self.board[1][row, col] != 0:
            raise ValueError("Playing in an already filled position")

        self.board[0, row, col] = 1
        self.board = np.array([self.board[1], self.board[0]])  # switch planes: plane 0 is always the player to move
        self.done = self._have_winner() or len(self.get_legal_actions()) == 0
        reward = 1 if self._have_winner() else 0
        self.current_player = (self.current_player + 1) % 2

        return deepcopy(self.board), reward, self.done, {}

    def _action_to_pos(self, action):
        assert 0 <= action <= 8
        row = action // 3
        col = action % 3
        return (row, col)

    def _pos_to_action(self, row, col):
        action = row * 3 + col
        return action

    def _have_winner(self):
        # Horizontal and vertical checks
        for i in range(3):
            if (self.board[0, i] == 1).all() or (self.board[1, i] == 1).all():
                return True  # horizontal
            if (self.board[0, :, i] == 1).all() or (self.board[1, :, i] == 1).all():
                return True  # vertical

        # Diagonals
        if (self.board[0, 0, 0] == 1 and self.board[0, 1, 1] == 1 and self.board[0, 2, 2] == 1 or
                self.board[1, 0, 0] == 1 and self.board[1, 1, 1] == 1 and self.board[1, 2, 2] == 1):
            return True

        if (self.board[0, 0, 2] == 1 and self.board[0, 1, 1] == 1 and self.board[0, 2, 0] == 1 or
                self.board[1, 0, 2] == 1 and self.board[1, 1, 1] == 1 and self.board[1, 2, 0] == 1):
            return True

        return False

    def _expert_action(self):
        board = self.board
        # +1 for the player to move (plane 0), -1 for the opponent (plane 1)
        summed_board = 1 * board[0] + -1 * board[1]
        action = np.random.choice(self.get_legal_actions())  # fallback: random legal action

        # Horizontal and vertical checks
        if self.expert_difficulty == 2:
            for i in range(3):
                if sum(summed_board[i, :]) == 2:  # attacking row position
                    col = np.where(summed_board[i, :] == 0)[0][0]
                    action = self._pos_to_action(i, col)
                    return action

                if sum(summed_board[:, i]) == 2:  # attacking col position
                    row = np.where(summed_board[:, i] == 0)[0][0]
                    action = self._pos_to_action(row, i)
                    return action

        if self.expert_difficulty >= 1:
            for i in range(3):
                if sum(summed_board[i, :]) == -2:  # defending row position
                    col = np.where(summed_board[i, :] == 0)[0][0]
                    action = self._pos_to_action(i, col)
                    return action

                if sum(summed_board[:, i]) == -2:  # defending col position
                    row = np.where(summed_board[:, i] == 0)[0][0]
                    action = self._pos_to_action(row, i)
                    return action

        # Diagonal checks
        diag = summed_board.diagonal()                  # top-left to bottom-right
        anti_diag = np.fliplr(summed_board).diagonal()  # bottom-left to top-right
        if self.expert_difficulty == 2:
            if sum(diag) == 2:  # attacking diagonal
                ind = np.where(diag == 0)[0][0]
                row = ind
                col = ind
                action = self._pos_to_action(row, col)
                return action

            if sum(anti_diag) == 2:  # attacking anti-diagonal
                ind = np.where(anti_diag == 0)[0][0]
                row = ind
                col = 2 - ind
                action = self._pos_to_action(row, col)
                return action

        if self.expert_difficulty >= 1:
            if sum(diag) == -2:  # defending diagonal
                ind = np.where(diag == 0)[0][0]
                row = ind
                col = ind
                action = self._pos_to_action(row, col)
                return action

            if sum(anti_diag) == -2:  # defending anti-diagonal
                ind = np.where(anti_diag == 0)[0][0]
                row = ind
                col = 2 - ind
                action = self._pos_to_action(row, col)
                return action

        return action

    def __str__(self):
        return "TicTacToe"
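To make the expert-difficulty docstring concrete, here is a small sketch of playing random moves against the built-in expert at difficulty 2. The environments/tictactoe.py import path is an assumption; the filename is not shown in the diff.

import numpy as np
# Hypothetical import path; the actual filename is not shown in this commit.
from environments.tictactoe import TicTacToe

env = TicTacToe(self_play=False, expert_start=False, expert_difficulty=2)
board = env.reset()
done = False
while not done:
    action = int(np.random.choice(env.get_legal_actions()))
    board, reward, done, _ = env.step(action)   # the expert replies inside step()
env.render()
print("final reward from the random player's perspective:", reward)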