forked from datamllab/rlcard
-
Notifications
You must be signed in to change notification settings - Fork 1
/
leducholdem.py
111 lines (89 loc) · 3.64 KB
/
leducholdem.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import json
import os
import numpy as np
from collections import OrderedDict
import rlcard
from rlcard.envs import Env
from rlcard.games.leducholdem import Game
from rlcard.utils import *
# Default game settings for this environment; assigned to
# ``default_game_config`` on each LeducholdemEnv instance.
DEFAULT_GAME_CONFIG = {'game_num_players': 2}
class LeducholdemEnv(Env):
    ''' Leduc Hold'em Environment
    '''

    def __init__(self, config):
        ''' Initialize the Leduc Hold'em environment

        Args:
            config (dict): Environment configuration, forwarded unchanged to
                the base ``Env`` initializer
        '''
        self.name = 'leduc-holdem'
        self.default_game_config = DEFAULT_GAME_CONFIG
        # The game object is created before the base initializer runs
        # (presumably Env.__init__ reads self.game -- confirm against Env).
        self.game = Game()
        super().__init__(config)

        # Position in this list is the integer action id used by agents.
        self.actions = ['call', 'raise', 'fold', 'check']
        # One entry per player: a flat observation vector of length 36.
        self.state_shape = [[36] for _ in range(self.num_players)]
        self.action_shape = [None for _ in range(self.num_players)]

        card2index_path = os.path.join(rlcard.__path__[0], 'games/leducholdem/card2index.json')
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(card2index_path, 'r', encoding='utf-8') as file:
            self.card2index = json.load(file)

    def _get_legal_actions(self):
        ''' Get all legal actions

        Returns:
            legal_actions (list): the legal actions exactly as reported by
                the underlying game (no re-encoding is done here)
        '''
        return self.game.get_legal_actions()

    def _extract_state(self, state):
        ''' Extract the state representation from state dictionary for agent

        The observation is a length-36 vector with four entries set to 1
        (three while there is no public card):
            * ``card2index[hand]``              -- the player's hand card
            * ``card2index[public_card] + 3``   -- the public card, if any
            * ``my_chips + 6``                  -- chips put in by the player
            * ``(sum(all_chips) - my_chips) + 21`` -- chips put in by the others
        NOTE(review): the offsets suggest card indices in 0..2 and chip
        counts in 0..14 -- confirm against card2index.json and the game rules.

        Args:
            state (dict): Original state from the game; must provide the keys
                'legal_actions', 'public_card', 'hand', 'my_chips' and
                'all_chips'

        Returns:
            extracted_state (dict): dictionary with the keys
                'legal_actions' (OrderedDict: action id -> None),
                'obs' (numpy array of length 36),
                'raw_obs' (the unmodified input state),
                'raw_legal_actions' (list copy of state['legal_actions']),
                'action_record' (self.action_recorder)
        '''
        extracted_state = {}

        # Map each legal action name to its integer id; the values are unused.
        legal_actions = OrderedDict({self.actions.index(a): None for a in state['legal_actions']})
        extracted_state['legal_actions'] = legal_actions

        public_card = state['public_card']
        hand = state['hand']
        my_chips = state['my_chips']

        obs = np.zeros(36)
        obs[self.card2index[hand]] = 1
        if public_card:
            obs[self.card2index[public_card] + 3] = 1
        obs[my_chips + 6] = 1
        # Chips contributed by everyone except the observing player.
        obs[sum(state['all_chips']) - my_chips + 21] = 1
        extracted_state['obs'] = obs

        extracted_state['raw_obs'] = state
        extracted_state['raw_legal_actions'] = list(state['legal_actions'])
        # action_recorder is not defined in this class
        # (presumably maintained by the base Env -- confirm).
        extracted_state['action_record'] = self.action_recorder

        return extracted_state

    def get_payoffs(self):
        ''' Get the payoff of a game

        Returns:
            payoffs (list): list of payoffs, as returned by the game
        '''
        return self.game.get_payoffs()

    def _decode_action(self, action_id):
        ''' Decode the action for applying to the game

        If the requested action is not currently legal, fall back to
        'check' when that is legal and to 'fold' otherwise.

        Args:
            action_id (int): action id, an index into ``self.actions``

        Returns:
            action (str): action for the game
        '''
        legal_actions = self.game.get_legal_actions()
        action = self.actions[action_id]
        if action in legal_actions:
            return action
        # Illegal request: prefer the passive option that is available.
        return 'check' if 'check' in legal_actions else 'fold'

    def get_perfect_information(self):
        ''' Get the perfect information of the current state

        Returns:
            (dict): A dictionary of all the perfect information of the
                current state, with the keys 'chips', 'public_card',
                'hand_cards', 'current_round', 'current_player' and
                'legal_actions'
        '''
        game = self.game
        player_ids = range(self.num_players)

        state = {}
        state['chips'] = [game.players[i].in_chips for i in player_ids]
        # No public card yet -> None.
        state['public_card'] = game.public_card.get_index() if game.public_card else None
        state['hand_cards'] = [game.players[i].hand.get_index() for i in player_ids]
        state['current_round'] = game.round_counter
        state['current_player'] = game.game_pointer
        state['legal_actions'] = game.get_legal_actions()
        return state