polish(pu): polish visit_count_to_action_dist method, polish reset_katago_game_state, polish go_alphazero_league_config
puyuan1996 committed Aug 2, 2023
1 parent aec6805 commit bfeeaf5
Showing 12 changed files with 395 additions and 158 deletions.
49 changes: 44 additions & 5 deletions lzero/mcts/ctree/ctree_alphazero/mcts_alphazero.cpp
@@ -8,6 +8,7 @@
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>

namespace py = pybind11;

@@ -181,11 +182,16 @@ class MCTS {
// std::cout << std::endl;

// Compute action_probs
std::vector<double> visit_logs;
for (int v : visits) {
visit_logs.push_back(std::log(v + 1e-10));
}
std::vector<double> action_probs = softmax(visit_logs, temperature);
// std::vector<double> visit_logs;
// for (int v : visits) {
// visit_logs.push_back(std::log(v + 1e-10));
// }
// std::vector<double> action_probs = softmax(visit_logs, temperature);

// Convert visits to std::vector<double>
std::vector<double> visits_d(visits.begin(), visits.end());
std::vector<double> action_probs = visit_count_to_action_distribution(visits_d, temperature);

// std::cout << "position15 " << std::endl;
// Select an action according to action_probs
int action;
@@ -257,7 +263,40 @@
}
}





private:
static std::vector<double> visit_count_to_action_distribution(const std::vector<double>& visits, double temperature) {
// Check if temperature is 0
if (temperature == 0) {
throw std::invalid_argument("Temperature cannot be 0");
}

// Check if all visit counts are 0
if (std::all_of(visits.begin(), visits.end(), [](double v){ return v == 0; })) {
throw std::invalid_argument("All visit counts cannot be 0");
}

std::vector<double> normalized_visits(visits.size());

// Divide visit counts by temperature
for (size_t i = 0; i < visits.size(); i++) {
normalized_visits[i] = visits[i] / temperature;
}

// Calculate the sum of all normalized visit counts
double sum = std::accumulate(normalized_visits.begin(), normalized_visits.end(), 0.0);

// Normalize the visit counts
for (double& visit : normalized_visits) {
visit /= sum;
}

return normalized_visits;
}

static std::vector<double> softmax(const std::vector<double>& values, double temperature) {
std::vector<double> exps;
double sum = 0.0;
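For reference, a minimal Python sketch (illustrative only, not repository code) contrasting the two conversions this hunk swaps: the previous softmax-over-log-visits and the new visit_count_to_action_distribution-style linear normalization.

import numpy as np

def softmax_of_log_visits(visits, temperature):
    # previous behaviour: softmax(log(N + eps) / T), i.e. N^(1/T) renormalized
    logits = np.log(np.asarray(visits, dtype=np.float64) + 1e-10) / temperature
    exps = np.exp(logits - logits.max())  # subtract max for numerical stability
    return exps / exps.sum()

def linear_visit_distribution(visits, temperature):
    # behaviour of the new C++ helper: divide by temperature, then renormalize
    if temperature == 0:
        raise ValueError("Temperature cannot be 0")
    scaled = np.asarray(visits, dtype=np.float64) / temperature
    if scaled.sum() == 0:
        raise ValueError("All visit counts cannot be 0")
    return scaled / scaled.sum()

visits = [10, 30, 60, 0]
print(softmax_of_log_visits(visits, 0.25))      # sharper than the raw visit frequencies
print(linear_visit_distribution(visits, 0.25))  # equal to the raw frequencies for any T > 0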
18 changes: 14 additions & 4 deletions lzero/mcts/ptree/ptree_az.py
@@ -118,6 +118,7 @@ def __init__(self, cfg: EasyDict) -> None:
'root_dirichlet_alpha', 0.3
) # 0.3 # for chess, 0.03 for Go and 0.15 for shogi.
self._root_noise_weight = self._cfg.get('root_noise_weight', 0.25) # 0.25
self.mcts_search_cnt = 0

def get_next_action(
self,
@@ -167,14 +168,22 @@ def get_next_action(
action_visits.append((action, 0))

actions, visits = zip(*action_visits)
print('action_visits= {}'.format(visits))
action_probs = nn.functional.softmax(1.0 / temperature * np.log(torch.as_tensor(visits) + 1e-10), dim=0).numpy()
# print('action_visits= {}'.format(visits))
# original code: visit_count 0 -> action_probs small positive number
# action_probs = nn.functional.softmax(np.log(torch.as_tensor(visits) + 1e-10) / temperature, dim=0).numpy()

visits_t = torch.as_tensor(visits, dtype=torch.float32)
visits_t /= temperature
action_probs = (visits_t / visits_t.sum()).numpy()

if sample:
action = np.random.choice(actions, p=action_probs)
else:
action = actions[np.argmax(action_probs)]
print('action= {}'.format(action))
print('action_probs= {}'.format(action_probs))
self.mcts_search_cnt += 1
# print(f'mcts_search_cnt: {self.mcts_search_cnt}')
# print('action= {}'.format(action))
# print('action_probs= {}'.format(action_probs))
return action, action_probs

def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_forward_fn: Callable) -> None:
@@ -188,6 +197,7 @@ def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_forward_fn:
- policy_forward_fn (:obj:`Function`): The Callable to compute the action probs and state value.
"""
while not node.is_leaf():
# print('here')
# print(node.children.keys())
action, node = self._select_child(node, simulate_env)
if action is None:
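A small self-contained sketch (example values, not repository code) of what get_next_action does with the visit counts after this change: normalize them directly, then either sample from the distribution (collection) or take the argmax (evaluation).

import numpy as np

actions = (0, 1, 2, 3)
visits = np.asarray([10.0, 30.0, 60.0, 0.0])

temperature = 1.0
scaled = visits / temperature             # mirrors the new normalization path
action_probs = scaled / scaled.sum()

sample = True
if sample:
    action = np.random.choice(actions, p=action_probs)    # exploratory choice during collect
else:
    action = actions[int(np.argmax(action_probs))]        # greedy choice during eval
print(action, action_probs)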
38 changes: 28 additions & 10 deletions lzero/policy/alphazero.py
@@ -146,7 +146,9 @@ def _init_learn(self) -> None:
from torch.optim.lr_scheduler import LambdaLR
max_step = self._cfg.threshold_training_steps_for_final_lr
# NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr.
lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa
# lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa
lr_lambda = lambda step: 1 if step < max_step * 0.33 else (0.1 if step < max_step * 0.66 else 0.01) # noqa

self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda)

# Algorithm config
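A hedged sketch of the piecewise schedule the new lr_lambda encodes (multipliers 1 -> 0.1 -> 0.01 at one third and two thirds of threshold_training_steps_for_final_lr); the model and optimizer below are placeholders, not the repository's.

import torch
from torch.optim.lr_scheduler import LambdaLR

max_step = int(1.5e6)                         # stand-in for threshold_training_steps_for_final_lr
model = torch.nn.Linear(8, 2)                 # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.02)

# NOTE: these values are multipliers on the base lr, not learning rates themselves.
lr_lambda = lambda step: 1 if step < max_step * 0.33 else (0.1 if step < max_step * 0.66 else 0.01)
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

for _ in range(3):
    optimizer.step()
    scheduler.step()                          # advances the step counter fed to lr_lambda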
@@ -160,12 +162,22 @@ def _init_learn(self) -> None:
self._learn_model = torch.compile(self._learn_model)

def _forward_learn(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, float]:
for d in inputs:
if 'katago_game_state' in d:
del d['katago_game_state']
print('delete katago_game_state')

inputs = default_collate(inputs)
for input_dict in inputs:
# Check and remove 'katago_game_state' from 'obs' if it exists
if 'katago_game_state' in input_dict['obs']:
del input_dict['obs']['katago_game_state']

# Check and remove 'katago_game_state' from 'next_obs' if it exists
if 'katago_game_state' in input_dict['next_obs']:
del input_dict['next_obs']['katago_game_state']
try:
# list of dict -> dict of list
inputs = default_collate(inputs)
except Exception as e:
print(f"Exception occurred: {e}")
print(f"Type of inputs: {type(inputs)}")
print(f"Is default_collate callable? {callable(default_collate)}")
raise

if self._cuda:
inputs = to_device(inputs, self._device)
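A minimal sketch of the pre-collate cleanup added above, with torch's default_collate standing in for the project's own collate helper; the transition layout is assumed for illustration.

import torch
from torch.utils.data import default_collate  # stand-in for the repo's collate helper

def strip_and_collate(transitions):
    # drop the non-tensor KataGo state from obs/next_obs before batching
    for t in transitions:
        t['obs'].pop('katago_game_state', None)
        t['next_obs'].pop('katago_game_state', None)
    # list of dicts -> dict of batched tensors
    return default_collate(transitions)

batch = strip_and_collate([
    {'obs': {'board': torch.zeros(6, 6), 'katago_game_state': object()},
     'next_obs': {'board': torch.zeros(6, 6)}}
    for _ in range(4)
])
print(batch['obs']['board'].shape)  # torch.Size([4, 6, 6])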
@@ -253,6 +265,7 @@ def _forward_collect(self, envs: Dict, obs: Dict, temperature: float = 1) -> Dic
self.collect_mcts_temperature = temperature
ready_env_id = list(envs.keys())
init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
katago_game_state = {env_id: obs[env_id]['katago_game_state'] for env_id in ready_env_id}
start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id}
output = {}
self._policy_model = self._collect_model
@@ -264,6 +277,7 @@
init_state=init_state[env_id],
# katago_policy_init=False,
katago_policy_init=True,
katago_game_state=katago_game_state[env_id],
)
# action, mcts_probs = self._collect_mcts.get_next_action(
# envs[env_id],
@@ -316,6 +330,7 @@ def _forward_eval(self, envs: Dict, obs: Dict) -> Dict[str, torch.Tensor]:
"""
ready_env_id = list(obs.keys())
init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
katago_game_state = {env_id: obs[env_id]['katago_game_state'] for env_id in ready_env_id}
start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id}
output = {}
self._policy_model = self._eval_model
@@ -326,9 +341,9 @@
envs[env_id].reset(
start_player_index=start_player_index[env_id],
init_state=init_state[env_id],
# katago_policy_init=False,
katago_policy_init=False,
# TODO(pu)
katago_policy_init=True,
# katago_policy_init=True,
)
try:
action, mcts_probs = self._eval_mcts.get_next_action(envs[env_id], self._policy_value_fn, 1.0, False)
@@ -374,8 +389,11 @@ def _process_transition(self, obs: Dict, model_output: Dict[str, torch.Tensor],
Overview:
Generate the dict type transition (one timestep) data from policy learning.
"""
if 'katago_game_state' in obs:
if 'katago_game_state' in obs.keys():
del obs['katago_game_state']
# if 'katago_game_state' in timestep.obs.keys():
# del timestep.obs['katago_game_state']
# Note: used in _foward_collect in alphazero_collector now

return {
'obs': obs,
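To make the collect/eval changes above concrete, a hedged, self-contained sketch of the per-env reset that now threads the stored KataGo game state back into each simulation env; only the keyword arguments mirror the diff, the env class and observation layout are stand-ins.

class FakeGoEnv:
    # illustrative stand-in for the Go simulation env
    def reset(self, start_player_index=0, init_state=None,
              katago_policy_init=False, katago_game_state=None):
        self.start_player_index = start_player_index
        self.board = init_state
        self.katago_policy_init = katago_policy_init
        self.katago_game_state = katago_game_state

obs = {0: {'board': [[0] * 6 for _ in range(6)],
           'current_player_index': 0,
           'katago_game_state': object()}}        # opaque KataGo state
ready_env_id = list(obs.keys())
envs = {env_id: FakeGoEnv() for env_id in ready_env_id}

init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
katago_game_state = {env_id: obs[env_id]['katago_game_state'] for env_id in ready_env_id}
start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id}

for env_id in ready_env_id:
    envs[env_id].reset(
        start_player_index=start_player_index[env_id],
        init_state=init_state[env_id],
        katago_policy_init=True,                  # collect path; the eval path switches this to False
        katago_game_state=katago_game_state[env_id],
    )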
1 change: 1 addition & 0 deletions lzero/worker/alphazero_collector.py
@@ -190,6 +190,7 @@ def collect(self,
self._obs_pool.update(obs_)
simulation_envs = {}
for env_id in ready_env_id:
# TODO(pu)
# create the new simulation env instances from the current collect env using the same env_config.
simulation_envs[env_id] = self._env._env_fn[env_id]()

1 change: 1 addition & 0 deletions zoo/atari/config/atari_muzero_config.py
@@ -56,6 +56,7 @@
game_segment_length=400,
use_augmentation=True,
update_per_collect=update_per_collect,
model_update_ratio=0.1,
batch_size=batch_size,
optim_type='SGD',
lr_piecewise_constant_decay=True,
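model_update_ratio is added here alongside update_per_collect; a hedged sketch of the usual relationship (helper name and exact formula are assumptions, not taken from this diff): when update_per_collect is left unset, the number of gradient updates is derived from the newly collected data.

def resolve_update_per_collect(update_per_collect, new_transitions, model_update_ratio=0.1):
    # assumed convention: an explicit update_per_collect wins; otherwise derive the
    # update count from the freshly collected transitions
    if update_per_collect is not None:
        return update_per_collect
    return max(int(new_transitions * model_update_ratio), 1)

print(resolve_update_per_collect(None, new_transitions=400))  # -> 40
print(resolve_update_per_collect(50, new_transitions=400))    # -> 50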
15 changes: 8 additions & 7 deletions zoo/board_games/go/config/go_alphazero_bot_mode_config.py
@@ -11,13 +11,6 @@
elif board_size == 6:
komi = 4

collector_env_num = 8
n_episode = 8
evaluator_env_num = 1
update_per_collect = 50
batch_size = 256
max_env_step = int(10e6)

if board_size == 19:
num_simulations = 800
elif board_size == 9:
@@ -27,6 +20,14 @@
# num_simulations = 80
num_simulations = 50

collector_env_num = 8
n_episode = 8
evaluator_env_num = 1
update_per_collect = 50
batch_size = 256
max_env_step = int(10e6)
num_channels = 64

board_size = 6
komi = 4
collector_env_num = 1
76 changes: 45 additions & 31 deletions zoo/board_games/go/config/go_alphazero_league_config.py
@@ -15,44 +15,51 @@
num_simulations = 800
elif board_size == 9:
num_simulations = 180
# num_simulations = 50
elif board_size == 6:
# num_simulations = 80
num_simulations = 50
num_simulations = 80

collector_env_num = 8
n_episode = 8
evaluator_env_num = 1
update_per_collect = 200
# update_per_collect = 200
update_per_collect = None
model_update_ratio = 0.1

batch_size = 256
max_env_step = int(100e6)
snapshot_the_player_in_iter_zero = True
one_phase_step = int(5e3)
# TODO(pu)
sp_prob = 0.5 # 0, 0.5, 1
use_bot_init_historical = False
sp_prob = 0.2 # 0, 0.5, 1
# use_bot_init_historical = False
use_bot_init_historical = True
# num_res_blocks = 5
# num_channels = 64
num_res_blocks = 10
num_channels = 128


# debug config
board_size = 6
komi = 4
collector_env_num = 1
n_episode = 1
evaluator_env_num = 1
update_per_collect = 2
batch_size = 2
max_env_step = int(2e5)
sp_prob = 0.
snapshot_the_player_in_iter_zero = True
one_phase_step = int(5)
num_simulations = 2
num_channels = 2
# board_size = 6
# komi = 4
# collector_env_num = 1
# n_episode = 1
# evaluator_env_num = 1
# update_per_collect = 2
# batch_size = 2
# max_env_step = int(2e5)
# sp_prob = 0.2
# one_phase_step = int(5)
# num_simulations = 5
# num_channels = 2
# num_res_blocks = 1

# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================

go_alphazero_league_config = dict(
exp_name=f"data_az_ctree_league/go_b{board_size}-komi-{komi}_alphazero_ns{num_simulations}_upc{update_per_collect}_league-sp-{sp_prob}_bot-init-{use_bot_init_historical}_phase-step-{one_phase_step}_seed0",
exp_name=f"data_az_ctree_league/go_b{board_size}-komi-{komi}_alphazero_nb-{num_res_blocks}-nc-{num_channels}_ns{num_simulations}_upc{update_per_collect}-mur-{model_update_ratio}_league-sp-{sp_prob}_bot-init-{use_bot_init_historical}_phase-step-{one_phase_step}_seed0",
env=dict(
stop_value=2,
env_name="Go",
@@ -63,8 +70,8 @@
scale=True,
agent_vs_human=False,
use_katago_bot=True,
katago_checkpoint_path="/Users/puyuan/code/KataGo/kata1-b18c384nbt-s6582191360-d3422816034/model.ckpt",
# katago_checkpoint_path="/mnt/nfs/puyuan/KataGo/kata1-b18c384nbt-s6582191360-d3422816034/model.ckpt",
# katago_checkpoint_path="/Users/puyuan/code/KataGo/kata1-b18c384nbt-s6582191360-d3422816034/model.ckpt",
katago_checkpoint_path="/mnt/nfs/puyuan/KataGo/kata1-b18c384nbt-s6582191360-d3422816034/model.ckpt",
ignore_pass_if_have_other_legal_actions=True,
bot_action_type='v0', # {'v0', 'alpha_beta_pruning'}
prob_random_action_in_bot=0,
@@ -84,7 +91,7 @@
model=dict(
observation_shape=(board_size, board_size, 17),
action_space_size=int(board_size * board_size + 1),
num_res_blocks=1,
num_res_blocks=num_res_blocks,
num_channels=num_channels,
),
# mcts_ctree=False,
@@ -93,17 +100,24 @@
env_type='board_games',
board_size=board_size,
update_per_collect=update_per_collect,
model_update_ratio=model_update_ratio,
batch_size=batch_size,
optim_type='Adam',
lr_piecewise_constant_decay=False,
learning_rate=0.003,
grad_clip_value=0.5,
# optim_type='Adam',
# lr_piecewise_constant_decay=False,
# learning_rate=0.003,

# OpenGo parameters
optim_type='SGD',
lr_piecewise_constant_decay=True,
learning_rate=0.02, # 0.02, 0.002, 0.0002
threshold_training_steps_for_final_lr=int(1.5e6),

# i.e. temperature: 1 -> 0.5 -> 0.25
manual_temperature_decay=True,
threshold_training_steps_for_final_temperature=int(1.5e6),

value_weight=1.0,
entropy_weight=0.0,
# NOTE:In board_games, we set large td_steps to make sure the value target is the final outcome.
td_steps=500,
# NOTE:In board_games, we set discount_factor=1.
discount_factor=1,
n_episode=n_episode,
eval_freq=int(2e3),
# eval_freq=int(100), # debug
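The league config above turns on manual_temperature_decay with the note "temperature: 1 -> 0.5 -> 0.25"; a hedged sketch of one common way to implement such a three-stage schedule (the breakpoints are an assumption, not read from this diff).

def visit_count_temperature(train_step, threshold=int(1.5e6), manual_decay=True):
    # assumed three-stage schedule over threshold_training_steps_for_final_temperature
    if not manual_decay:
        return 1.0
    if train_step < 0.5 * threshold:
        return 1.0
    if train_step < threshold:
        return 0.5
    return 0.25

for step in (0, int(1e6), int(2e6)):
    print(step, visit_count_temperature(step))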