fix(pu): add reset_katago_game_state() method in go_env reset()
puyuan1996 committed Jul 26, 2023 (commit aec6805, 1 parent 5372979)
Showing 8 changed files with 169 additions and 99 deletions.
24 changes: 12 additions & 12 deletions lzero/entry/train_alphazero_league.py
@@ -206,18 +206,18 @@ def load_checkpoint_fn(player_id: str, ckpt_path: str):
     set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
     league_iter = 0
     while True:
-        # if evaluator.should_eval(main_learner.train_iter):
-        #     stop_flag, eval_episode_info = evaluator.eval(
-        #         main_learner.save_checkpoint, main_learner.train_iter, main_collector.envstep
-        #     )
-        #     win_loss_result = win_loss_draw(eval_episode_info)
-        #
-        #     # set eval bot rating as 100.
-        #     main_player.rating = league.metric_env.rate_1vsC(
-        #         main_player.rating, league.metric_env.create_rating(mu=100, sigma=1e-8), win_loss_result
-        #     )
-        #     if stop_flag:
-        #         break
+        if evaluator.should_eval(main_learner.train_iter):
+            stop_flag, eval_episode_info = evaluator.eval(
+                main_learner.save_checkpoint, main_learner.train_iter, main_collector.envstep
+            )
+            win_loss_result = win_loss_draw(eval_episode_info)
+
+            # set eval bot rating as 100.
+            main_player.rating = league.metric_env.rate_1vsC(
+                main_player.rating, league.metric_env.create_rating(mu=100, sigma=1e-8), win_loss_result
+            )
+            if stop_flag:
+                break
 
         for player_id, player_ckpt_path in zip(league.active_players_ids, league.active_players_ckpts):
             tb_logger.add_scalar(
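
Note: this hunk re-enables the periodic evaluation of the main player against a fixed-strength bot inside the league loop. Below is a minimal standalone sketch of the rating update it performs, assuming a TrueSkill-style backend such as the one DI-engine's league metric env wraps behind rate_1vsC; the trueskill package here is an illustrative stand-in, not the repository's import.

import trueskill

ts = trueskill.TrueSkill(draw_probability=0.0)
main = ts.create_rating()                    # main player's current rating
bot = ts.create_rating(mu=100, sigma=1e-8)   # eval bot pinned at mu=100, as in the hunk

# One evaluated episode that the main player won (rank 0 beats rank 1).
# The bot's near-zero sigma keeps its own rating effectively frozen.
(main,), (bot,) = ts.rate([(main,), (bot,)], ranks=[0, 1])
print(main.mu, main.sigma)
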
8 changes: 7 additions & 1 deletion lzero/policy/alphazero.py
@@ -160,6 +160,11 @@ def _init_learn(self) -> None:
             self._learn_model = torch.compile(self._learn_model)
 
     def _forward_learn(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, float]:
+        for d in inputs:
+            if 'katago_game_state' in d:
+                del d['katago_game_state']
+                print('delete katago_game_state')
+
         inputs = default_collate(inputs)
 
         if self._cuda:
@@ -257,7 +262,8 @@ def _forward_collect(self, envs: Dict, obs: Dict, temperature: float = 1) -> Dict:
             envs[env_id].reset(
                 start_player_index=start_player_index[env_id],
                 init_state=init_state[env_id],
-                katago_policy_init=False,
+                # katago_policy_init=False,
+                katago_policy_init=True,
             )
             # action, mcts_probs = self._collect_mcts.get_next_action(
             #     envs[env_id],
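
Note: two things happen in this file. _forward_learn now strips the raw katago_game_state objects out of the training batch before collation, and _forward_collect resets envs with katago_policy_init=True. The first matters because collate functions can stack tensors and numbers but raise on opaque objects. A self-contained sketch of the failure mode and the fix, using PyTorch's torch.utils.data.default_collate (torch >= 1.11) as a stand-in for the collate helper the policy actually imports:

import torch
from torch.utils.data import default_collate

batch = [
    {'obs': torch.zeros(3), 'reward': 1.0, 'katago_game_state': object()},
    {'obs': torch.ones(3), 'reward': 0.0, 'katago_game_state': object()},
]

# Without this loop, default_collate(batch) raises TypeError on the opaque
# object -- the same situation the lines added to _forward_learn avoid.
for d in batch:
    d.pop('katago_game_state', None)

collated = default_collate(batch)
print(collated['obs'].shape)  # torch.Size([2, 3])
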
62 changes: 38 additions & 24 deletions zoo/board_games/go/config/go_alphazero_bot_mode_config.py
@@ -4,40 +4,53 @@
 # begin of the most frequently changed config specified by the user
 # ==============================================================
 board_size = 6
+# board_size = 9
 
+if board_size in [9, 19]:
+    komi = 7.5
+elif board_size == 6:
+    komi = 4
+
 collector_env_num = 8
 n_episode = 8
-evaluator_env_num = 5
+evaluator_env_num = 1
 update_per_collect = 50
 batch_size = 256
 max_env_step = int(10e6)
 prob_random_action_in_bot = 1
 
 if board_size == 19:
     num_simulations = 800
 elif board_size == 9:
-    num_simulations = 180
+    # num_simulations = 180
+    num_simulations = 50
 elif board_size == 6:
-    num_simulations = 80
+    # num_simulations = 80
+    num_simulations = 50
+
+board_size = 6
+komi = 4
+collector_env_num = 1
+n_episode = 1
+evaluator_env_num = 1
+num_simulations = 2
+update_per_collect = 2
+batch_size = 2
+max_env_step = int(5e5)
+prob_random_action_in_bot = 0.
+num_channels = 2
 
-# board_size = 6
-# collector_env_num = 1
-# n_episode = 1
-# evaluator_env_num = 1
-# num_simulations = 2
-# update_per_collect = 2
-# batch_size = 2
-# max_env_step = int(5e5)
-# prob_random_action_in_bot = 0.
 # ==============================================================
 # end of the most frequently changed config specified by the user
 # ==============================================================
-gomoku_alphazero_config = dict(
+go_alphazero_config = dict(
     exp_name=
-    f'data_az_ptree/gomoku_alphazero_bot-mode_rand{prob_random_action_in_bot}_ns{num_simulations}_upc{update_per_collect}_seed0',
+    f'data_az_ptree/go_alphazero_bot-mode_rand{prob_random_action_in_bot}_ns{num_simulations}_upc{update_per_collect}_seed0',
     env=dict(
         board_size=board_size,
-        komi=7.5,
+        use_katago_bot=True,
+        katago_checkpoint_path="/Users/puyuan/code/KataGo/kata1-b18c384nbt-s6582191360-d3422816034/model.ckpt",
+        # katago_checkpoint_path="/mnt/nfs/puyuan/KataGo/kata1-b18c384nbt-s6582191360-d3422816034/model.ckpt",
         battle_mode='play_with_bot_mode',
         bot_action_type='v0',
         prob_random_action_in_bot=prob_random_action_in_bot,
@@ -54,9 +67,10 @@
             observation_shape=(board_size, board_size, 17),
             action_space_size=int(1 * board_size * board_size + 1),
             num_res_blocks=1,
-            num_channels=64,
+            num_channels=num_channels,
         ),
-        mcts_ctree=False,
+        # mcts_ctree=False,
+        mcts_ctree=True,
         cuda=True,
         board_size=board_size,
         update_per_collect=update_per_collect,
@@ -76,13 +90,13 @@
     ),
 )
 
-gomoku_alphazero_config = EasyDict(gomoku_alphazero_config)
-main_config = gomoku_alphazero_config
+go_alphazero_config = EasyDict(go_alphazero_config)
+main_config = go_alphazero_config
 
-gomoku_alphazero_create_config = dict(
+go_alphazero_create_config = dict(
     env=dict(
-        type='gomoku',
-        import_names=['zoo.board_games.gomoku.envs.gomoku_env'],
+        type='go_lightzero',
+        import_names=['zoo.board_games.go.envs.go_env'],
     ),
     env_manager=dict(type='subprocess'),
     policy=dict(
@@ -99,8 +113,8 @@
         import_names=['lzero.worker.alphazero_evaluator'],
     )
 )
-gomoku_alphazero_create_config = EasyDict(gomoku_alphazero_create_config)
-create_config = gomoku_alphazero_create_config
+go_alphazero_create_config = EasyDict(go_alphazero_create_config)
+create_config = go_alphazero_create_config
 
 if __name__ == '__main__':
     if main_config.policy.tensor_float_32:
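
Note: besides the gomoku-to-go renames (go_alphazero_config, env type 'go_lightzero'), this file now keeps an uncommented debug block in the "most frequently changed" section. Python executes module-level assignments top to bottom, so those debug values silently override the tuned ones above them. A tiny illustration of the shadowing rule the config relies on (not part of the diff):

num_simulations = 50   # tuned value, assigned first
batch_size = 256

num_simulations = 2    # debug override lower in the file wins
batch_size = 2

assert (num_simulations, batch_size) == (2, 2)
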
42 changes: 23 additions & 19 deletions zoo/board_games/go/config/go_alphazero_league_config.py
@@ -32,35 +32,39 @@
 sp_prob = 0.5 # 0, 0.5, 1
 use_bot_init_historical = False
 
-# collector_env_num = 1
-# n_episode = 1
-# evaluator_env_num = 1
-# update_per_collect = 2
-# batch_size = 2
-# max_env_step = int(2e5)
-# sp_prob = 0.
-# snapshot_the_player_in_iter_zero = True
-# one_phase_step = int(5)
-# num_simulations = 2
+# debug config
+board_size = 6
+komi = 4
+collector_env_num = 1
+n_episode = 1
+evaluator_env_num = 1
+update_per_collect = 2
+batch_size = 2
+max_env_step = int(2e5)
+sp_prob = 0.
+snapshot_the_player_in_iter_zero = True
+one_phase_step = int(5)
+num_simulations = 2
+num_channels = 2
 
 # ==============================================================
 # end of the most frequently changed config specified by the user
 # ==============================================================
 
 go_alphazero_league_config = dict(
-    exp_name=f"data_az_ptree_league/go_b{board_size}-komi-{komi}_alphazero_ns{num_simulations}_upc{update_per_collect}_league-sp-{sp_prob}_bot-init-{use_bot_init_historical}_phase-step-{one_phase_step}_seed0",
+    exp_name=f"data_az_ctree_league/go_b{board_size}-komi-{komi}_alphazero_ns{num_simulations}_upc{update_per_collect}_league-sp-{sp_prob}_bot-init-{use_bot_init_historical}_phase-step-{one_phase_step}_seed0",
     env=dict(
         stop_value=2,
         env_name="Go",
         board_size=board_size,
-        komi=7.5,
+        komi=komi,
         battle_mode='self_play_mode',
         mcts_mode='self_play_mode', # only used in AlphaZero
         scale=True,
         agent_vs_human=False,
         use_katago_bot=True,
-        # katago_checkpoint_path="/Users/puyuan/code/KataGo/kata1-b18c384nbt-s6582191360-d3422816034/model.ckpt",
-        katago_checkpoint_path="/mnt/nfs/puyuan/KataGo/kata1-b18c384nbt-s6582191360-d3422816034/model.ckpt",
+        katago_checkpoint_path="/Users/puyuan/code/KataGo/kata1-b18c384nbt-s6582191360-d3422816034/model.ckpt",
+        # katago_checkpoint_path="/mnt/nfs/puyuan/KataGo/kata1-b18c384nbt-s6582191360-d3422816034/model.ckpt",
         ignore_pass_if_have_other_legal_actions=True,
         bot_action_type='v0', # {'v0', 'alpha_beta_pruning'}
         prob_random_action_in_bot=0,
@@ -81,16 +85,16 @@
             observation_shape=(board_size, board_size, 17),
             action_space_size=int(board_size * board_size + 1),
             num_res_blocks=1,
-            num_channels=64,
+            num_channels=num_channels,
         ),
-        mcts_ctree=False,
-        # mcts_ctree=True,
+        # mcts_ctree=False,
+        mcts_ctree=True,
         cuda=True,
         env_type='board_games',
         board_size=board_size,
         update_per_collect=update_per_collect,
         batch_size=batch_size,
-        optim_type='AdamW',
+        optim_type='Adam',
         lr_piecewise_constant_decay=False,
         learning_rate=0.003,
         grad_clip_value=0.5,
@@ -110,7 +114,7 @@
         log_freq_for_payoff_rank=50,
         player_category=['go'],
         # path to save policy of league player, user can specify this field
-        path_policy=f"data_az_ptree_league/go_alphazero_league_sp-{sp_prob}_bot-init-{use_bot_init_historical}_phase-step-{one_phase_step}_ns{num_simulations}_policy_ckpt_seed0",
+        path_policy=f"data_az_ctree_league/go_alphazero_league_sp-{sp_prob}_bot-init-{use_bot_init_historical}_phase-step-{one_phase_step}_ns{num_simulations}_policy_ckpt_seed0",
         active_players=dict(main_player=1, ),
         main_player=dict(
             # An active player will be considered trained enough for snapshot after two phase steps.
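
Note: this file derives komi from board_size, points katago_checkpoint_path back at the local machine, switches the optimizer from AdamW to Adam, and flips mcts_ctree to True while renaming the ptree output directories to ctree to match. A hypothetical refactor (not in the diff) that derives the backend tag once, so the flag and the paths cannot drift apart:

mcts_ctree = True
backend = 'ctree' if mcts_ctree else 'ptree'  # single source of truth

exp_name = f"data_az_{backend}_league/go_alphazero_league_seed0"
path_policy = f"data_az_{backend}_league/go_alphazero_league_policy_ckpt_seed0"
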
27 changes: 14 additions & 13 deletions zoo/board_games/go/config/go_alphazero_sp_mode_config.py
@@ -18,6 +18,7 @@
 update_per_collect = 50
 batch_size = 256
 max_env_step = int(10e6)
+num_channels = 64
 
 if board_size == 19:
     num_simulations = 800
@@ -28,18 +29,17 @@
     # num_simulations = 80
     num_simulations = 50
 
-# board_size = 6
-# komi = 4
-# # board_size = 9
-# # komi = 7.5
-# collector_env_num = 1
-# n_episode = 1
-# evaluator_env_num = 1
-# num_simulations = 2
-# update_per_collect = 2
-# batch_size = 2
-# max_env_step = int(5e5)
-# prob_random_action_in_bot = 0.
+board_size = 6
+komi = 4
+collector_env_num = 1
+n_episode = 1
+evaluator_env_num = 1
+num_simulations = 2
+update_per_collect = 2
+batch_size = 2
+max_env_step = int(5e4)
+prob_random_action_in_bot = 0.
+num_channels = 2
 # ==============================================================
 # end of the most frequently changed config specified by the user
 # ==============================================================
@@ -70,9 +70,10 @@
             observation_shape=(board_size, board_size, 17),
             action_space_size=int(board_size * board_size + 1),
             num_res_blocks=1,
-            num_channels=64,
+            num_channels=num_channels,
         ),
+        # mcts_ctree=False,
         mcts_ctree=True,
         cuda=True,
         board_size=board_size,
         update_per_collect=update_per_collect,
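
Note: as in the other configs, num_channels is lifted into a variable: 64 for full runs, 2 in the debug block. Narrowing the channel width is what makes the smoke-test config cheap. A rough, hypothetical proxy for the effect on model size (the real AlphaZero net has more blocks and heads than this):

import torch.nn as nn

def conv_block_params(channels: int, in_planes: int = 17) -> int:
    # Two 3x3 convs, roughly one residual block of the policy-value net.
    block = nn.Sequential(
        nn.Conv2d(in_planes, channels, kernel_size=3, padding=1),
        nn.Conv2d(channels, channels, kernel_size=3, padding=1),
    )
    return sum(p.numel() for p in block.parameters())

print(conv_block_params(64))  # ~47k parameters
print(conv_block_params(2))   # ~350 parameters
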
[Diffs for the remaining 3 changed files are not rendered on this page.]
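
Note: the unrendered files include the zoo/board_games/go/envs/go_env.py change named in the commit title. A minimal, hypothetical sketch of the idea -- every name below except reset_katago_game_state() is illustrative, not the repository's actual code:

class GoEnv:
    def __init__(self, board_size=6, komi=4, use_katago_bot=True):
        self.board_size, self.komi = board_size, komi
        self.use_katago_bot = use_katago_bot
        self.katago_game_state = None

    def reset_katago_game_state(self):
        # Rebuild the KataGo-side game state from an empty board so that a
        # recycled env never carries moves over from the previous episode.
        self.katago_game_state = {'board_size': self.board_size,
                                  'komi': self.komi, 'moves': []}

    def reset(self, start_player_index=0, init_state=None, katago_policy_init=True):
        self.current_player = start_player_index + 1
        self.board = init_state  # plus the usual board/legal-move bookkeeping
        if self.use_katago_bot and katago_policy_init:
            self.reset_katago_game_state()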
