Skip to content

Commit

Permalink
chore: some adjustments. ➕
Browse files Browse the repository at this point in the history
  • Loading branch information
Joker2770 committed Apr 26, 2023
1 parent a2d59bc commit 582585e
Show file tree
Hide file tree
Showing 11 changed files with 78 additions and 44 deletions.
19 changes: 10 additions & 9 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,18 @@ target_link_libraries(pbrain-Z2I onnxruntime)

add_subdirectory(test)

if (MSVC)
file(GLOB ONNX_DLLS "${ONNXRUNTIME_ROOTDIR}/lib/*.dll")
add_custom_command(TARGET pbrain-Z2I POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different
${ONNX_DLLS} $<TARGET_FILE_DIR:pbrain-Z2I>)
endif (MSVC)
if(MSVC)
file(GLOB ONNX_DLLS "${ONNXRUNTIME_ROOTDIR}/lib/*.dll")
elseif(UNIX)
file(GLOB ONNX_DLLS "${ONNXRUNTIME_ROOTDIR}/lib/lib*.so*")
endif()

if (UNIX)
file(GLOB ONNX_DLLS "${ONNXRUNTIME_ROOTDIR}/lib/lib*.so*")
add_custom_command(TARGET pbrain-Z2I POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different
if(DEFINED ONNX_DLLS)
add_custom_command(TARGET pbrain-Z2I POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different
${ONNX_DLLS} $<TARGET_FILE_DIR:pbrain-Z2I>)
endif (UNIX)
else()
message("ONNX_DLLS does not exist or is empty.")
endif()

# ADD_EXECUTABLE(AlphaZeroInference main.cpp)

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ CMakefiles.txt: convert the onnxruntime path to your own path

```shell
mkdir build
cp *.sh ./build/
cp ./scripts/*.sh ./build/
cd ./build
cmake .. # (or "cmake -A x64 ..")
cmake --build . --config Release # (or open .sln file through visual Studio 19 and generate for win10)
Expand Down
31 changes: 31 additions & 0 deletions scripts/train.bat
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
@echo off
set n=1
set batch_num=1
set do_prepare=0

if %do_prepare%==1 (
echo preparing........
train_net.bat prepare
python ..\\python\\learner.py
) else (
echo skip prepare!
)
echo start train......

:LOOP
if %%n LEQ 2000 (
echo --------------%%n-th train------------------
set i=0
:INNER_LOOP
if %%i LSS %batch_num% (
echo generate data......
start /B cmd /C "train_net.bat generate %%i"
set /A i+=1
goto INNER_LOOP
)
python ..\\python\\learner.py train
:: train_net.bat eval_with_winner 10
train_net.bat eval_with_random 10
set /A n+=1
goto LOOP
)
2 changes: 1 addition & 1 deletion train.sh → scripts/train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ do
# bash ./train_net.sh eval_with_winner 10
bash ./train_net.sh eval_with_random 10
let n++
done
done
2 changes: 1 addition & 1 deletion train_miniconda3.sh → scripts/train_miniconda3.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ do
# bash ./train_net.sh eval_with_winner 10
bash ./train_net.sh eval_with_random 10
let n++
done
done
1 change: 1 addition & 0 deletions scripts/train_net.bat
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.\\train_eval_net.exe %1 %2
1 change: 1 addition & 0 deletions scripts/train_net.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
exec ./train_eval_net $1 $2
2 changes: 1 addition & 1 deletion src/pbrain-Z2I/pbrain-Z2I.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ using namespace std;

bool isNumericString(const char *str, unsigned int i_len)
{
for (int i = 0; i < i_len; i++)
for (unsigned int i = 0; i < i_len; i++)
{
if (!isdigit(str[i]))
{
Expand Down
42 changes: 21 additions & 21 deletions src/train_and_eval/train_eval_net.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ void generate_data_for_train(int current_weight, int start_batch_id)
SelfPlay *sp = new SelfPlay(model);
sp->self_play_for_train(NUM_TRAIN_THREADS, start_batch_id);

if (nullptr != model)
{
delete model;
model = nullptr;
if (nullptr != model)
{
delete model;
model = nullptr;
}
if (nullptr != sp)
{
delete sp;
sp = nullptr;
if (nullptr != sp)
{
delete sp;
sp = nullptr;
}
}

Expand Down Expand Up @@ -140,20 +140,20 @@ vector<int> eval(int weight_a, int weight_b, unsigned int game_num, int a_sims,
}
// cout << "win_table = " << win_table[0] << win_table[1] << win_table [2] << endl;

if (nullptr != nn_a)
{
delete nn_a;
nn_a = nullptr;
if (nullptr != nn_a)
{
delete nn_a;
nn_a = nullptr;
}
if (nullptr != nn_b)
{
delete nn_b;
nn_b = nullptr;
if (nullptr != nn_b)
{
delete nn_b;
nn_b = nullptr;
}
if (nullptr != thread_pool)
{
delete thread_pool;
thread_pool = nullptr;
if (nullptr != thread_pool)
{
delete thread_pool;
thread_pool = nullptr;
}

return {win_table[0], win_table[1], win_table[2]};
Expand Down Expand Up @@ -199,7 +199,7 @@ int main(int argc, char *argv[])
}
// logger_reader >> temp[1];
logger_reader.close();
cout << "Generating... current_weight = " << current_weight << endl;
cout << "Generating... current_weight = " << current_weight << " start batch id: " << argv[2] << endl;
generate_data_for_train(current_weight, atoi(argv[2]) * NUM_TRAIN_THREADS);
}
else if (strcmp(argv[1], "eval_with_winner") == 0)
Expand Down
19 changes: 10 additions & 9 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@ target_link_libraries(mcts_test onnxruntime)
target_link_libraries(get_best_action_from_prob_test Z2I_lib)
target_link_libraries(get_best_action_from_prob_test onnxruntime)

if (MSVC)
file(GLOB ONNX_DLLS "${ONNXRUNTIME_ROOTDIR}/lib/*.dll")
add_custom_command(TARGET mcts_test POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different
${ONNX_DLLS} $<TARGET_FILE_DIR:mcts_test>)
endif (MSVC)
if(MSVC)
file(GLOB ONNX_DLLS "${ONNXRUNTIME_ROOTDIR}/lib/*.dll")
elseif(UNIX)
file(GLOB ONNX_DLLS "${ONNXRUNTIME_ROOTDIR}/lib/lib*.so*")
endif()

if (UNIX)
file(GLOB ONNX_DLLS "${ONNXRUNTIME_ROOTDIR}/lib/lib*.so*")
add_custom_command(TARGET mcts_test POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different
if(DEFINED ONNX_DLLS)
add_custom_command(TARGET mcts_test POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different
${ONNX_DLLS} $<TARGET_FILE_DIR:mcts_test>)
endif (UNIX)
else()
message("ONNX_DLLS does not exist or is empty.")
endif()

# add_test(NAME mcts_test COMMAND "mcts_test")
1 change: 0 additions & 1 deletion train_net.sh

This file was deleted.

0 comments on commit 582585e

Please sign in to comment.