Skip to content

Commit

Permalink
fix(pu): fix simulate_env_copy.battle_mode and polish softmax in mcts…
Browse files Browse the repository at this point in the history
…_alphazero.cpp
  • Loading branch information
puyuan1996 committed Jul 26, 2023
1 parent db6cca1 commit 5372979
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 55 deletions.
70 changes: 35 additions & 35 deletions lzero/mcts/ctree/ctree_alphazero/mcts_alphazero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <pybind11/stl.h>
#include <functional>
#include <iostream>
#include <memory>

namespace py = pybind11;

Expand Down Expand Up @@ -62,15 +63,14 @@ class MCTS {

// Define the _select_child and _expand_leaf_node functions in the MCTS class
std::pair<int, Node*> _select_child(Node* node, py::object simulate_env) {

int action = -1;
Node* child = nullptr;
double best_score = -9999999;
for (const auto& kv : node->children) {
int action_tmp = kv.first;
Node* child_tmp = kv.second;
// Node* child_tmp = kv.second.get();

// py::list legal_actions_py = simulate_env.attr("legal_actions")().cast<py::list>();
py::list legal_actions_py = simulate_env.attr("legal_actions").cast<py::list>();

std::vector<int> legal_actions;
Expand Down Expand Up @@ -129,14 +129,13 @@ class MCTS {
// }

// std::cout << "position16 " << std::endl;

for (const auto& kv : action_probs_dict) {
// std::cout << "position17 " << std::endl;

int action = kv.first;
double prior_p = kv.second;
if (std::find(legal_actions.begin(), legal_actions.end(), action) != legal_actions.end()) {
node->children[action] = new Node(node, prior_p);
// node->children[action] = std::make_unique<Node>(node, prior_p);
}
}
// std::cout << "position18 " << std::endl;
Expand All @@ -145,62 +144,49 @@ class MCTS {
}

std::pair<int, std::vector<double>> get_next_action(py::object simulate_env, py::object policy_forward_fn, double temperature, bool sample) {
// std::cout << "position1 " << std::endl;
// printf("position1 \n");
Node* root = new Node();
// std::cout << "position2 " << std::endl;
_expand_leaf_node(root, simulate_env, policy_forward_fn);
// std::cout << "position3 " << std::endl;

if (sample) {
// std::cout << "position4 " << std::endl;
_add_exploration_noise(root);
// std::cout << "position5 " << std::endl;

}
// std::cout << "position6 " << std::endl;

for (int n = 0; n < num_simulations; ++n) {
// std::cout << "position7 " << std::endl;
py::object simulate_env_copy = simulate_env.attr("clone")();
// std::cout << "position8 " << std::endl;
simulate_env_copy.attr("battle_mode") = simulate_env_copy.attr("mcts_mode");
_simulate(root, simulate_env_copy, policy_forward_fn);
// std::cout << "position9 " << std::endl;
simulate_env_copy = py::none();
// std::cout << "position10 " << std::endl;
}

std::vector<std::pair<int, int>> action_visits;
// std::cout << "position11 " << std::endl;
for (int action = 0; action < simulate_env.attr("action_space").attr("n").cast<int>(); ++action) {
if (root->children.count(action)) {
// std::cout << "position12 " << std::endl;

action_visits.push_back(std::make_pair(action, root->children[action]->visit_count));
} else {
// std::cout << "position13 " << std::endl;

action_visits.push_back(std::make_pair(action, 0));
}
}

// 转换action_visits为两个分离的数组
std::vector<int> actions;
std::vector<int> visits;
// std::cout << "position14 " << std::endl;

for (const auto& av : action_visits) {
actions.push_back(av.first);
visits.push_back(av.second);
}

// std::cout << "Action visits: ";
// for(const auto& visit : visits) {
// std::cout << visit << " ";
// }
// std::cout << std::endl;

// 计算action_probs
std::vector<double> visit_logs;
for (int v : visits) {
visit_logs.push_back(std::log(v + 1e-10));
}
std::vector<double> action_probs = softmax(visit_logs, temperature);
// std::cout << "position15 " << std::endl;

// 根据action_probs选择一个action
int action;
if (sample) {
Expand All @@ -209,13 +195,18 @@ class MCTS {
action = actions[std::distance(action_probs.begin(), std::max_element(action_probs.begin(), action_probs.end()))];
}
// std::cout << "position16 " << std::endl;

// printf("Action: %d\n", action);
// std::cout << "Action probabilities: ";
// for(const auto& prob : action_probs) {
// std::cout << prob << " ";
// }
// std::cout << std::endl;

return std::make_pair(action, action_probs);
}

void _simulate(Node* node, py::object simulate_env, py::object policy_forward_fn) {
// std::cout << "position21 " << std::endl;

while (!node->is_leaf()) {
int action;
std::tie(action, node) = _select_child(node, simulate_env);
Expand All @@ -227,14 +218,12 @@ class MCTS {

bool done;
int winner;
// std::tie(done, winner) = simulate_env.attr("get_done_winner")();
py::tuple result = simulate_env.attr("get_done_winner")();
done = result[0].cast<bool>();
winner = result[1].cast<int>();

double leaf_value;
// std::cout << "position22 " << std::endl;

if (!done) {
leaf_value = _expand_leaf_node(node, simulate_env, policy_forward_fn);
// std::cout << "position23 " << std::endl;
Expand All @@ -261,8 +250,6 @@ class MCTS {
}
}
// std::cout << "position25 " << std::endl;

// if (simulate_env.attr("mcts_mode") == "play_with_bot_mode") {
if (simulate_env.attr("mcts_mode").cast<std::string>() == "play_with_bot_mode") {
node->update_recursive(leaf_value, simulate_env.attr("mcts_mode").cast<std::string>());
} else if (simulate_env.attr("mcts_mode").cast<std::string>() == "self_play_mode") {
Expand All @@ -274,14 +261,20 @@ class MCTS {
// Temperature-scaled softmax over `values`.
//
// Computes exp((v - max) / temperature) for every element and normalizes the
// results so they sum to 1. Subtracting the maximum first keeps the largest
// exponent at exp(0) == 1, so large inputs cannot overflow.
//
// @param values       Input scores; may be empty.
// @param temperature  Divisor applied to the shifted scores. Must be non-zero
//                     (a zero temperature yields NaN from 0/0).
// @return A vector of the same length as `values`; empty when `values` is empty.
static std::vector<double> softmax(const std::vector<double>& values, double temperature) {
    std::vector<double> exps;
    // std::max_element returns end() for an empty range; dereferencing it
    // would be undefined behavior, so bail out early.
    if (values.empty()) {
        return exps;
    }
    exps.reserve(values.size());

    const double max_value = *std::max_element(values.begin(), values.end());

    double sum = 0.0;
    for (double v : values) {
        // Shift by the maximum before exponentiating, for numerical stability.
        const double exp_v = std::exp((v - max_value) / temperature);
        exps.push_back(exp_v);
        sum += exp_v;
    }

    for (double& exp_v : exps) {
        exp_v /= sum;
    }

    return exps;
}

Expand All @@ -296,19 +289,26 @@ class MCTS {

PYBIND11_MODULE(mcts_alphazero, m) {
py::class_<Node>(m, "Node")
// .def(py::init<Node*, float>())
.def(py::init([](Node* parent, float prior_p){
return new Node(parent ? parent : nullptr, prior_p);
}), py::arg("parent")=nullptr, py::arg("prior_p")=1.0)
.def("value", &Node::get_value)
// .def("value", &Node::get_value)
.def_property_readonly("value", &Node::get_value)
// .def_property("value", &Node::get_value, &Node::set_value)

.def("update", &Node::update)
.def("update_recursive", &Node::update_recursive)
.def("is_leaf", &Node::is_leaf)
.def("is_root", &Node::is_root)
.def("parent", &Node::get_parent)
.def_readwrite("prior_p", &Node::prior_p)

// .def("children", &Node::get_children)
.def_readwrite("children", &Node::children)
.def("visit_count", &Node::get_visit_count);
.def("add_child", &Node::add_child)
.def_readwrite("visit_count", &Node::visit_count)
.def("end_game", &Node::end_game, "A function to end the game");


py::class_<MCTS>(m, "MCTS")
.def(py::init<int, int, double, double, double, double>(),
Expand Down
3 changes: 2 additions & 1 deletion lzero/mcts/ctree/ctree_alphazero/node_alphazero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ PYBIND11_MODULE(node_alphazero, m) {
.def("is_leaf", &Node::is_leaf)
.def("is_root", &Node::is_root)
.def("parent", &Node::get_parent)
// .def("children", &Node::get_children)
.def("children", &Node::get_children)
.def_readwrite("children", &Node::children)
.def("add_child", &Node::add_child)
.def("visit_count", &Node::get_visit_count);
}

Expand Down
94 changes: 82 additions & 12 deletions lzero/mcts/ctree/ctree_alphazero/node_alphazero.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,54 @@
#include <map>
#include <string>
#include <iostream>
#include <memory>
#include <mutex>

class Node {
// std::mutex mtx; // mutex
// std::recursive_mutex mtx; // recursive mutex
public:
Node(Node* parent = nullptr, float prior_p = 1.0)
: parent(parent), prior_p(prior_p), visit_count(0), value_sum(0.0) {}

// Destroys every child stored in `children`; each child's own destructor
// does the same, so the whole subtree below this node is freed.
~Node() {
    for (auto it = children.begin(); it != children.end(); ++it) {
        delete it->second;
    }
}

// Removes this node's entry from its parent's `children` map.
//
// Does nothing when the node has no parent or when the parent does not list
// this node (previously the latter case called erase(end()), which is
// undefined behavior). The `parent` pointer itself is left unchanged.
void remove_from_parent() {
    if (parent == nullptr) {
        return;
    }
    auto& siblings = parent->children;
    for (auto it = siblings.begin(); it != siblings.end(); ++it) {
        if (it->second == this) {
            // Erase exactly one entry and stop: `it` is invalid after erase().
            siblings.erase(it);
            return;
        }
    }
}

// Frees the tree rooted at `root` by delegating to delete_subtree().
// NOTE(review): delete_subtree() finishes with `delete node`, so `root` and
// everything below it must not be used after this call — confirm that callers
// (including the Python binding of this method) respect that.
void end_game(Node* root) {
// Assumes that none of the nodes in the tree are needed once the game has ended.
delete_subtree(root);
}

// Detaches `node` from its parent and destroys it together with all of its
// descendants. Passing nullptr is a no-op.
//
// The previous version recursed into every child while range-iterating
// `node->children`; each recursive call ran the child's remove_from_parent(),
// which erased the very map entry the loop was standing on (iterator
// invalidation, undefined behavior). The recursion is unnecessary because
// ~Node() already deletes all children, so deleting `node` frees the subtree.
// Leftover debug printf calls were removed as well.
void delete_subtree(Node* node) {
    if (node == nullptr) {
        return;
    }
    // Unlink first so the parent is not left holding a dangling pointer.
    node->remove_from_parent();
    delete node;
}


// node->remove_from_parent();
// delete node;

// Mean value of this node: value_sum divided by visit_count, or 0 when the
// node has not been visited yet (avoids dividing by zero).
float get_value() {
    if (visit_count == 0) {
        return 0.0f;
    }
    return value_sum / visit_count;
}
Expand All @@ -16,12 +59,32 @@ class Node {
}

void update_recursive(float leaf_value, std::string mcts_mode) {
update(leaf_value);
if (!is_root()) {
if (mcts_mode == "self_play_mode") {

// printf("parent pointer: %p\n", parent);
// printf("position-ur-1 \n");

// std::lock_guard<std::mutex> lock(mtx); // 自动获取锁
// std::lock_guard<std::recursive_mutex> lock(mtx); // 自动获取锁
// 在 lock_guard 对象析构时自动释放锁


if (mcts_mode == "self_play_mode") {
// printf("position-ur-2 \n");
// printf("leaf_value: %f\n", leaf_value);

update(leaf_value);

// printf("position-ur-3 \n");
if (!is_root()) {
// printf("position-ur-4 \n");
parent->update_recursive(-leaf_value, mcts_mode);
// printf("position-ur-5 \n");
}
else if (mcts_mode == "play_with_bot_mode") {
// printf("position-ur-6 \n");
}
else if (mcts_mode == "play_with_bot_mode") {
update(leaf_value);
if (!is_root()) {
parent->update_recursive(leaf_value, mcts_mode);
}
}
Expand All @@ -43,21 +106,28 @@ class Node {
return children;
}

// std::map<int, std::unique_ptr<Node>> get_children() {
// return children;
// }


// Returns the current value of `visit_count`.
int get_visit_count() {
return visit_count;
}

// Stores `node` as the child reached by `action`.
// Any pointer already stored under `action` is overwritten, not deleted.
// NOTE(review): `node->parent` is not modified here — presumably the caller
// constructs `node` with this node as its parent; confirm.
void add_child(int action, Node* node) {
children[action] = node;
}
// void add_child(int action, std::unique_ptr<Node> node) {
// children[action] = std::move(node);
// }

public:
Node* parent;
float prior_p;
int visit_count;
float value_sum;
std::map<int, Node*> children; // or std::vector<Node*>

// private:
// Node* _parent;
// float _prior_p;
// int _visit_count;
// float _value_sum;
// std::map<int, Node*> _children;
std::map<int, Node*> children;
// std::map<int, std::unique_ptr<Node>> children;

};
4 changes: 3 additions & 1 deletion lzero/mcts/ptree/ptree_az.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,14 @@ def get_next_action(
action_visits.append((action, 0))

actions, visits = zip(*action_visits)
print('action_visits= {}'.format(visits))
action_probs = nn.functional.softmax(1.0 / temperature * np.log(torch.as_tensor(visits) + 1e-10), dim=0).numpy()
if sample:
action = np.random.choice(actions, p=action_probs)
else:
action = actions[np.argmax(action_probs)]
# print(action)
print('action= {}'.format(action))
print('action_probs= {}'.format(action_probs))
return action, action_probs

def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_forward_fn: Callable) -> None:
Expand Down
Loading

0 comments on commit 5372979

Please sign in to comment.