diff --git a/trackalert/trackalert-luastate.hh b/trackalert/trackalert-luastate.hh index 69bd6399..3a9972e0 100644 --- a/trackalert/trackalert-luastate.hh +++ b/trackalert/trackalert-luastate.hh @@ -21,6 +21,7 @@ */ #pragma once + #include "ext/luawrapper/include/LuaContext.hpp" #include "misc.hh" #include @@ -42,7 +43,9 @@ typedef std::function background_t; extern background_t g_background; typedef std::unordered_map bg_func_map_t; -vector> setupLua(bool client, bool allow_report, LuaContext& c_lua, report_t& report_func, bg_func_map_t* bg_func_map, CustomFuncMap& custom_func_map, const std::string& config); +vector> +setupLua(bool client, bool allow_report, LuaContext& c_lua, report_t& report_func, bg_func_map_t* bg_func_map, + CustomFuncMap& custom_func_map, const std::string& config); struct LuaThreadContext { LuaContext lua_context; @@ -54,40 +57,46 @@ struct LuaThreadContext { #define NUM_LUA_STATES 6 -class LuaMultiThread -{ +class LuaMultiThread { public: - LuaMultiThread() : num_states(NUM_LUA_STATES), - state_index(0) + LuaMultiThread() : num_states(NUM_LUA_STATES) { LuaMultiThread{num_states}; } - LuaMultiThread(unsigned int nstates) : num_states(nstates), - state_index(0) + LuaMultiThread(unsigned int nstates) : num_states(nstates) { - for (unsigned int i=0; i()); - } + for (unsigned int i = 0; i < num_states; i++) { + lua_pool.push_back(std::make_shared()); + } + lua_read_only = lua_pool; // Make a copy for use by the control thread } LuaMultiThread(const LuaMultiThread&) = delete; + LuaMultiThread& operator=(const LuaMultiThread&) = delete; - + // these are used to setup the function pointers - std::vector>::iterator begin() { return lua_cv.begin(); } - std::vector>::iterator end() { return lua_cv.end(); } + std::vector>::iterator begin() + { return lua_read_only.begin(); } - void report(const LoginTuple& lt) { - auto lt_context = getLuaState(); + std::vector>::iterator end() + { return lua_read_only.end(); } + + void report(const LoginTuple& lt) + { + auto pool_member = getPoolMember(); + auto lt_context = pool_member.getLuaContext(); // lock the lua state mutex std::lock_guard lock(lt_context->lua_mutex); // call the report function lt_context->report_func(lt); } - void background(const std::string& func_name) { - auto lt_context = getLuaState(); + void background(const std::string& func_name) + { + auto pool_member = getPoolMember(); + auto lt_context = pool_member.getLuaContext(); // lock the lua state mutex std::lock_guard lock(lt_context->lua_mutex); // call the background function @@ -96,31 +105,49 @@ public: fn->second(); } - CustomFuncReturn custom_func(const std::string& command, const CustomFuncArgs& cfa) { - auto lt_context = getLuaState(); + CustomFuncReturn custom_func(const std::string& command, const CustomFuncArgs& cfa) + { + auto pool_member = getPoolMember(); + auto lt_context = pool_member.getLuaContext(); // lock the lua state mutex std::lock_guard lock(lt_context->lua_mutex); // call the custom function - for (const auto& i : lt_context->custom_func_map) { + for (const auto& i: lt_context->custom_func_map) { if (command.compare(i.first) == 0) { - return i.second.c_func(cfa); + return i.second.c_func(cfa); } } return CustomFuncReturn(false, KeyValVector{}); } - + protected: - std::shared_ptr getLuaState() - { + class SharedPoolMember { + public: + SharedPoolMember(std::shared_ptr ptr, LuaMultiThread* pool) : d_pool_item(ptr), d_pool(pool) {} + ~SharedPoolMember() { if (d_pool != nullptr) { d_pool->returnPoolMember(d_pool_item); } } + SharedPoolMember(const SharedPoolMember&) = delete; + SharedPoolMember& operator=(const SharedPoolMember&) = delete; + std::shared_ptr getLuaContext() { return d_pool_item; } + private: + std::shared_ptr d_pool_item; + LuaMultiThread* d_pool; + }; + + SharedPoolMember getPoolMember() { std::lock_guard lock(mutx); - if (state_index >= num_states) - state_index = 0; - return lua_cv[state_index++]; + auto member = lua_pool.back(); + lua_pool.pop_back(); + return SharedPoolMember(member, this); } + void returnPoolMember(std::shared_ptr my_ptr) { + std::lock_guard lock(mutx); + lua_pool.push_back(my_ptr); + } + private: - std::vector> lua_cv; + std::vector> lua_pool; + std::vector> lua_read_only; unsigned int num_states; - unsigned int state_index; std::mutex mutx; }; diff --git a/wforce/luastate.hh b/wforce/luastate.hh index d6dcc925..b8d44cd6 100644 --- a/wforce/luastate.hh +++ b/wforce/luastate.hh @@ -21,6 +21,7 @@ */ #pragma once + #include "ext/luawrapper/include/LuaContext.hpp" #include "misc.hh" #include @@ -41,7 +42,7 @@ typedef std::function canonicalize_t; struct CustomFuncMapObject { custom_func_t c_func; - bool c_reportSink; + bool c_reportSink; }; typedef std::map CustomFuncMap; @@ -50,7 +51,10 @@ extern CustomFuncMap g_custom_func_map; typedef std::map CustomGetFuncMap; extern CustomGetFuncMap g_custom_get_func_map; -vector> setupLua(bool client, bool allow_report, LuaContext& c_lua, allow_t& allow_func, report_t& report_func, reset_t& reset_func, canonicalize_t& canon_func, CustomFuncMap& custom_func_map, CustomGetFuncMap& custom_get_func_map, const std::string& config); +vector> +setupLua(bool client, bool allow_report, LuaContext& c_lua, allow_t& allow_func, report_t& report_func, + reset_t& reset_func, canonicalize_t& canon_func, CustomFuncMap& custom_func_map, + CustomGetFuncMap& custom_get_func_map, const std::string& config); struct LuaThreadContext { LuaContext lua_context; @@ -65,101 +69,132 @@ struct LuaThreadContext { #define NUM_LUA_STATES 6 -class LuaMultiThread -{ +class LuaMultiThread { public: - LuaMultiThread() : num_states(NUM_LUA_STATES), - state_index(0) + + LuaMultiThread() : num_states(NUM_LUA_STATES) { LuaMultiThread{num_states}; } - LuaMultiThread(unsigned int nstates) : num_states(nstates), - state_index(0) + LuaMultiThread(unsigned int nstates) : num_states(nstates) { - for (unsigned int i=0; i()); - } + for (unsigned int i = 0; i < num_states; i++) { + lua_pool.push_back(std::make_shared()); + } + lua_read_only = lua_pool; // Make a copy for use by the control thread } LuaMultiThread(const LuaMultiThread&) = delete; + LuaMultiThread& operator=(const LuaMultiThread&) = delete; // these are used to setup the allow and report function pointers - std::vector>::iterator begin() { return lua_cv.begin(); } - std::vector>::iterator end() { return lua_cv.end(); } + std::vector>::iterator begin() + { return lua_read_only.begin(); } + + std::vector>::iterator end() + { return lua_read_only.end(); } - bool reset(const std::string& type, const std::string& login_value, const ComboAddress& ca_value) { - auto lt_context = getLuaState(); + bool reset(const std::string& type, const std::string& login_value, const ComboAddress& ca_value) + { + auto pool_member = getPoolMember(); + auto lt_context = pool_member.getLuaContext(); // lock the lua state mutex std::lock_guard lock(lt_context->lua_mutex); // call the reset function return lt_context->reset_func(type, login_value, ca_value); } - AllowReturn allow(const LoginTuple& lt) { - auto lt_context = getLuaState(); + AllowReturn allow(const LoginTuple& lt) + { + auto pool_member = getPoolMember(); + auto lt_context = pool_member.getLuaContext(); // lock the lua state mutex std::lock_guard lock(lt_context->lua_mutex); // call the allow function return lt_context->allow_func(lt); } - void report(const LoginTuple& lt) { - auto lt_context = getLuaState(); + void report(const LoginTuple& lt) + { + auto pool_member = getPoolMember(); + auto lt_context = pool_member.getLuaContext(); // lock the lua state mutex std::lock_guard lock(lt_context->lua_mutex); // call the report function lt_context->report_func(lt); } - std::string canonicalize(const std::string& login) { - auto lt_context = getLuaState(); + std::string canonicalize(const std::string& login) + { + auto pool_member = getPoolMember(); + auto lt_context = pool_member.getLuaContext(); // lock the lua state mutex std::lock_guard lock(lt_context->lua_mutex); // call the canonicalize function return lt_context->canon_func(login); } - CustomFuncReturn custom_func(const std::string& command, const CustomFuncArgs& cfa, bool& reportSinkReturn) { - auto lt_context = getLuaState(); + CustomFuncReturn custom_func(const std::string& command, const CustomFuncArgs& cfa, bool& reportSinkReturn) + { + auto pool_member = getPoolMember(); + auto lt_context = pool_member.getLuaContext(); // lock the lua state mutex std::lock_guard lock(lt_context->lua_mutex); // call the custom function - for (const auto& i : lt_context->custom_func_map) { + for (const auto& i: lt_context->custom_func_map) { if (command.compare(i.first) == 0) { - reportSinkReturn = i.second.c_reportSink; - return i.second.c_func(cfa); + reportSinkReturn = i.second.c_reportSink; + return i.second.c_func(cfa); } } return CustomFuncReturn(false, KeyValVector{}); } - std::string custom_get_func(const std::string& command) { - auto lt_context = getLuaState(); + std::string custom_get_func(const std::string& command) + { + auto pool_member = getPoolMember(); + auto lt_context = pool_member.getLuaContext(); // lock the lua state mutex std::lock_guard lock(lt_context->lua_mutex); // call the custom function - for (const auto& i : lt_context->custom_get_func_map) { + for (const auto& i: lt_context->custom_get_func_map) { if (command.compare(i.first) == 0) { - return i.second(); + return i.second(); } } return string(); } - + protected: - std::shared_ptr getLuaState() - { + + class SharedPoolMember { + public: + SharedPoolMember(std::shared_ptr ptr, LuaMultiThread* pool) : d_pool_item(ptr), d_pool(pool) {} + ~SharedPoolMember() { if (d_pool != nullptr) { d_pool->returnPoolMember(d_pool_item); } } + SharedPoolMember(const SharedPoolMember&) = delete; + SharedPoolMember& operator=(const SharedPoolMember&) = delete; + std::shared_ptr getLuaContext() { return d_pool_item; } + private: + std::shared_ptr d_pool_item; + LuaMultiThread* d_pool; + }; + SharedPoolMember getPoolMember() { + std::lock_guard lock(mutx); + auto member = lua_pool.back(); + lua_pool.pop_back(); + return SharedPoolMember(member, this); + } + void returnPoolMember(std::shared_ptr my_ptr) { std::lock_guard lock(mutx); - if (state_index >= num_states) - state_index = 0; - return lua_cv[state_index++]; + lua_pool.push_back(my_ptr); } + private: - std::vector> lua_cv; + std::vector> lua_pool; + std::vector> lua_read_only; unsigned int num_states; - unsigned int state_index; std::mutex mutx; };