forked from datamllab/rlcard
-
Notifications
You must be signed in to change notification settings - Fork 1
/
uno_rule_models.py
124 lines (92 loc) · 3.13 KB
/
uno_rule_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
''' UNO rule models
'''
import numpy as np
import rlcard
from rlcard.models.model import Model
class UNORuleAgentV1(object):
    ''' UNO Rule agent version 1

    A hand-written baseline policy that works directly on raw (string)
    states and actions.
    '''

    def __init__(self):
        # This agent consumes raw states and emits raw action strings.
        self.use_raw = True

    def step(self, state):
        ''' Predict the action given raw state. A naive rule:

        1. If 'draw' is among the legal actions, draw.
        2. Otherwise, if a wild-draw-4 is legal, play it and name the color
           that appears most in the hand (wild cards are ignored for the
           count unless the hand holds only wild cards).
        3. Otherwise pick a random legal action, keeping wild cards for as
           long as possible.

        Args:
            state (dict): Raw state from the game. Reads
                state['raw_legal_actions'] and state['raw_obs']['hand'].

        Returns:
            action (str): Predicted action
        '''
        legal_actions = state['raw_legal_actions']
        raw_obs = state['raw_obs']
        if 'draw' in legal_actions:
            return 'draw'

        hand = raw_obs['hand']

        # If we have wild-4 simply play it and choose color that appears most in hand
        for action in legal_actions:
            if not self._is_wild_draw_4(action):
                continue
            color_nums = self.count_colors(self.filter_wild(hand))
            if not color_nums:
                # Nothing to count (empty hand): play the action as offered
                # instead of letting max() fail on an empty dict.
                return action
            preferred = max(color_nums, key=color_nums.get) + '-wild_draw_4'
            # Only return the preferred color when it is really legal;
            # otherwise fall back to the wild-draw-4 action that was offered.
            return preferred if preferred in legal_actions else action

        # Without wild-4, we randomly choose one
        return np.random.choice(self.filter_wild(legal_actions))

    def eval_step(self, state):
        ''' Step for evaluation. The same to step

        Args:
            state (dict): Raw state from the game

        Returns:
            result (tuple): The action chosen by step, and an empty list
        '''
        return self.step(state), []

    @staticmethod
    def _is_wild_draw_4(action):
        ''' Tell whether an action string is '<color>-wild_draw_4'

        Args:
            action (str): A raw action string

        Returns:
            is_wild_draw_4 (boolean): True if the part after the first dash
            is 'wild_draw_4'. Strings without a dash give False.
        '''
        parts = action.split('-')
        return len(parts) > 1 and parts[1] == 'wild_draw_4'

    @staticmethod
    def filter_wild(hand):
        ''' Filter the wild cards. If all are wild cards, we do not filter

        Args:
            hand (list): A list of UNO card string

        Returns:
            filtered_hand (list): A filtered list of UNO string
        '''
        # Characters 2..5 hold 'wild' for both '<c>-wild' and
        # '<c>-wild_draw_4'.
        filtered_hand = [card for card in hand if card[2:6] != 'wild']
        # Nothing left means every card was wild: hand back the input as is.
        return filtered_hand or hand

    @staticmethod
    def count_colors(hand):
        ''' Count the number of cards in each color in hand

        Args:
            hand (list): A list of UNO card string

        Returns:
            color_nums (dict): The number cards of each color, keyed by the
            first character of the card string, in first-seen order
        '''
        color_nums = {}
        for card in hand:
            color = card[0]
            color_nums[color] = color_nums.get(color, 0) + 1
        return color_nums
class UNORuleModelV1(Model):
    ''' Rule-based UNO model, version 1
    '''

    def __init__(self):
        ''' Set up the rule agents

        An 'uno' environment is created only to read its number of players;
        one UNORuleAgentV1 instance is then shared by every position.
        '''
        num_seats = rlcard.make('uno').num_players
        shared_agent = UNORuleAgentV1()
        # Every slot refers to the same agent object.
        self.rule_agents = [shared_agent] * num_seats

    @property
    def agents(self):
        ''' Get the list of agents, one per position in the game

        Returns:
            agents (list): A list of agents

        Note: Each agent should behave like an RL agent, with working step
        and eval_step methods.
        '''
        return self.rule_agents

    @property
    def use_raw(self):
        ''' Tell whether raw states and actions are used

        Returns:
            use_raw (boolean): Always True for this model
        '''
        return True