run_dmc.py
''' An example of training a Deep Monte-Carlo (DMC) Agent on PettingZoo environments
wrapping RLCard
'''
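
# Note: this example assumes PettingZoo's classic games and RLCard's PyTorch
# extras are installed (for example, `pip install pettingzoo[classic] rlcard[torch]`);
# the available classic environments and their version suffixes may differ
# across PettingZoo releases.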
import os
import argparse
from pettingzoo.classic import (
    leduc_holdem_v4,
    texas_holdem_v4,
    dou_dizhu_v4,
    mahjong_v4,
    texas_holdem_no_limit_v6,
    uno_v4,
    gin_rummy_v4,
)
from rlcard.agents.dmc_agent import DMCTrainer

env_name_to_env_func = {
    "leduc-holdem": leduc_holdem_v4,
    "limit-holdem": texas_holdem_v4,
    "doudizhu": dou_dizhu_v4,
    "mahjong": mahjong_v4,
    "no-limit-holdem": texas_holdem_no_limit_v6,
    "uno": uno_v4,
    "gin-rummy": gin_rummy_v4,
}


def train(args):
    # Make the environment
    env_func = env_name_to_env_func[args.env]
    env = env_func.env()
    env.reset()

    # Initialize the DMC trainer
    trainer = DMCTrainer(
        env,
        is_pettingzoo_env=True,
        load_model=args.load_model,
        xpid=args.xpid,
        savedir=args.savedir,
        save_interval=args.save_interval,
        num_actor_devices=args.num_actor_devices,
        num_actors=args.num_actors,
        training_device=args.training_device,
        total_frames=args.total_frames,
    )

    # Train DMC agents
    trainer.start()


if __name__ == '__main__':
    parser = argparse.ArgumentParser("DMC example in RLCard")
    parser.add_argument(
        '--env',
        type=str,
        default='leduc-holdem',
        choices=[
            'leduc-holdem',
            'limit-holdem',
            'doudizhu',
            'mahjong',
            'no-limit-holdem',
            'uno',
            'gin-rummy',
        ],
    )
    parser.add_argument(
        '--cuda',
        type=str,
        default='',
    )
    parser.add_argument(
        '--load_model',
        action='store_true',
        help='Load an existing model',
    )
    parser.add_argument(
        '--xpid',
        default='leduc_holdem',
        help='Experiment id (default: leduc_holdem)',
    )
    parser.add_argument(
        '--savedir',
        default='experiments/dmc_result',
        help='Root dir where experiment data will be saved',
    )
    parser.add_argument(
        '--save_interval',
        default=30,
        type=int,
        help='Time interval (in minutes) at which to save the model',
    )
    parser.add_argument(
        '--num_actor_devices',
        default=1,
        type=int,
        help='The number of devices used for simulation',
    )
    parser.add_argument(
        '--num_actors',
        default=5,
        type=int,
        help='The number of actors for each simulation device',
    )
    parser.add_argument(
        '--total_frames',
        default=100_000_000_000,
        type=int,
        help='The total number of frames to train for',
    )
    parser.add_argument(
        '--training_device',
        default=0,
        type=int,
        help='The index of the GPU used for training models',
    )
    args = parser.parse_args()

    # Restrict the GPUs visible to the actor and training processes
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
    train(args)
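
# Example invocation (an illustrative sketch; adjust the flag values to your setup):
#
#   python run_dmc.py --env leduc-holdem --cuda 0 --xpid leduc_holdem \
#       --savedir experiments/dmc_result --save_interval 30 --num_actors 5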