diff --git a/src/server/llm/llama/llama3_chat_server.cpp b/src/server/llm/llama/llama3_chat_server.cpp
index eee203e..2a01cfb 100644
--- a/src/server/llm/llama/llama3_chat_server.cpp
+++ b/src/server/llm/llama/llama3_chat_server.cpp
@@ -37,6 +37,7 @@ using jinq::common::FilePathUtil;
 
 namespace llm {
 using models::llm::Dialog;
+using models::llm::ChatMessage;
 using LLMPtr = models::llm::llama::Llama3<Dialog&, std::string>;
 
 namespace llama {
@@ -109,7 +110,7 @@ class Llama3ChatServer::Impl {
     int max_connection_nums = 200;
     int peer_resp_timeout = 15 * 1000;
     int compute_threads = -1;
-    int handler_threads = 50;
+    int handler_threads = 25;
     size_t request_size_limit = -1;
 
   private:
@@ -144,7 +145,7 @@ class Llama3ChatServer::Impl {
         std::string task_finished_ts;
        bool is_task_req_valid = false;
         std::string gen_out;
-        dialog_task d_task;
+        dialog_task* d_task = nullptr;
     };
     // dialog cache
     std::unordered_map<std::string, Dialog> _m_user_history_dialogs;
@@ -156,7 +157,7 @@ class Llama3ChatServer::Impl {
     /***
      *
      * @param task
      * @return
      */
-    static StatusCode parse_request(const protocol::HttpRequest* req, dialog_task& task);
+    static StatusCode parse_request(const protocol::HttpRequest* req, dialog_task* task);
 
     /***
      *
@@ -266,9 +267,10 @@ void Llama3ChatServer::Impl::serve_process(WFHttpTask* task) {
     // parse request body
     auto* req = task->get_req();
     auto* resp = task->get_resp();
-    dialog_task d_task{};
+    auto* d_task = new dialog_task;
     parse_request(req, d_task);
-    if (!d_task.is_valid) {
+    if (!d_task->is_valid) {
         task->get_resp()->append_output_body(fmt::format("invalid request data: {}", protocol::HttpUtil::decode_chunked_body(req)));
+        delete d_task;  // don't leak the heap-allocated task on invalid requests
         return;
     }
@@ -306,39 +307,38 @@ void Llama3ChatServer::Impl::serve_process(WFHttpTask* task) {
  * @param task
  * @return
  */
-StatusCode Llama3ChatServer::Impl::parse_request(const protocol::HttpRequest* req, dialog_task& task) {
+StatusCode Llama3ChatServer::Impl::parse_request(const protocol::HttpRequest* req, dialog_task* task) {
     // set task uuid
     protocol::HttpHeaderMap map(req);
     if (!map.key_exists("cookie")) {
-        task.uuid = server_internal_impl::generate_uuid();
+        task->uuid = server_internal_impl::generate_uuid();
     } else {
-        task.uuid = map.get("cookie");
+        task->uuid = map.get("cookie");
     }
 
     std::string req_body = protocol::HttpUtil::decode_chunked_body(req);
     rapidjson::Document doc;
     doc.Parse(req_body.c_str());
     if (!doc.IsObject()) {
-        task.is_valid = false;
+        task->is_valid = false;
         LOG(ERROR) << fmt::format("parse request body failed, invalid json str: {}", req_body);
         return StatusCode::SERVER_RUN_FAILED;
     }
 
     if (doc.HasMember("task_id")) {
-        task.task_id = doc["task_id"].GetString();
+        task->task_id = doc["task_id"].GetString();
     }
     if (!doc.HasMember("data")) {
-        task.is_valid = false;
+        task->is_valid = false;
         LOG(ERROR) << (fmt::format("invalid json str: {}, missing \"data\" field", req_body));
         return StatusCode::SERVER_RUN_FAILED;
     }
 
     auto messages = doc["data"].GetArray();
     for (auto& msg : messages) {
-        auto role = msg["role"].GetString();
-        auto content = msg["content"].GetString();
-        llama_chat_message chat_msg = {role, content};
-        task.current_dialog.push_back(chat_msg);
+        std::string role = msg["role"].GetString();
+        std::string content = msg["content"].GetString();
+        task->current_dialog.push_back({role, content});
     }
 
     return StatusCode::OK;
@@ -352,18 +352,18 @@ StatusCode Llama3ChatServer::Impl::parse_request(const protocol::HttpRequest* re
 void Llama3ChatServer::Impl::complete_chat(seriex_ctx* ctx) {
     // fetch current dialog
     auto task = ctx->d_task;
-    Dialog dialog = task.current_dialog;
+    Dialog dialog = task->current_dialog;
 
     // generate response
-    auto status = _m_generator->chat_completion(task.current_dialog, ctx->gen_out);
+    auto status = _m_generator->chat_completion(task->current_dialog, ctx->gen_out);
 
     // cache history dialog
-    llama_chat_message msg = {"assistant", ctx->gen_out.c_str()};
+    ChatMessage msg = {"assistant", ctx->gen_out};
     dialog.push_back(msg);
-    if (_m_user_history_dialogs.find(task.uuid) != _m_user_history_dialogs.end()) {
-        _m_user_history_dialogs[task.uuid] += dialog;
+    if (_m_user_history_dialogs.find(task->uuid) != _m_user_history_dialogs.end()) {
+        _m_user_history_dialogs[task->uuid] += dialog;
     } else {
-        _m_user_history_dialogs.insert(std::make_pair(task.uuid, dialog));
+        _m_user_history_dialogs.insert(std::make_pair(task->uuid, dialog));
     }
 
     // check if context exceeded occurred
@@ -430,7 +430,7 @@ void Llama3ChatServer::Impl::complete_chat_cb(const WFGoTask* task) {
     auto response_body = buffer.GetString();
     ctx->response->append_output_body(response_body);
-    ctx->response->add_header_pair("Set-Cookie", ctx->d_task.uuid);
+    ctx->response->add_header_pair("Set-Cookie", ctx->d_task->uuid);
 
     // update task count
     _m_finished_jobs++;
@@ -449,7 +449,7 @@ StatusCode Llama3ChatServer::Impl::regenerate_with_cache_dialogs(
     auto task = ctx->d_task;
     // prepare summary dialog
     Dialog summary_dialogs;
-    auto history_dialogs = _m_user_history_dialogs[task.uuid];
+    auto history_dialogs = _m_user_history_dialogs[task->uuid];
     std::string fmt_string;
     auto status = _m_generator->apply_chat_template(history_dialogs, false, fmt_string);
     if (status != StatusCode::OK) {
@@ -482,7 +482,7 @@ StatusCode Llama3ChatServer::Impl::regenerate_with_cache_dialogs(
     summary_dialogs.push_back({"system", "You are an assistant skilled at generating summaries."});
     summary_dialogs.push_back(
         {"user", fmt::format("Please summarize the multi-turn conversation above "
-                             "in content not exceeding {} tokens.", summary_token_nums).c_str()}
+                             "in content not exceeding {} tokens.", summary_token_nums)}
     );
 
     // check summary dialog token nums
@@ -520,17 +520,17 @@ StatusCode Llama3ChatServer::Impl::regenerate_with_cache_dialogs(
     for (auto i = msg_idx; i < history_dialogs.size(); ++i) {
         updated_dialog.push_back(history_dialogs[i]);
     }
-    _m_user_history_dialogs[task.uuid].clean_cache();
-    _m_user_history_dialogs[task.uuid] = updated_dialog;
+    _m_user_history_dialogs[task->uuid].clean_cache();
+    _m_user_history_dialogs[task->uuid] = updated_dialog;
 
     // regenerate response content
     _m_generator->clear_kv_cache_cell();
-    Dialog cur_dialog = updated_dialog + task.current_dialog;
+    Dialog cur_dialog = updated_dialog + task->current_dialog;
     status = _m_generator->chat_completion(cur_dialog, ctx->gen_out);
 
     // cache dialog
-    _m_user_history_dialogs[task.uuid] += task.current_dialog;
-    _m_user_history_dialogs[task.uuid] += Dialog("assistant", ctx->gen_out);
+    _m_user_history_dialogs[task->uuid] += task->current_dialog;
+    _m_user_history_dialogs[task->uuid] += Dialog("assistant", ctx->gen_out);
 
     return status;
 }
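
A note on the ownership change in serve_process: dialog_task moves from a stack value to a bare new so it can outlive the handler and travel with the series context, which makes every exit path responsible for freeing it. Below is a minimal sketch of a leak-proof alternative using std::unique_ptr; the dialog_task and seriex_ctx definitions are abbreviated stand-ins rather than the real ones from llama3_chat_server.cpp, and serve_process_sketch is a hypothetical mirror of the control flow, not the server's actual code:

// Sketch only: unique_ptr ownership for the heap-allocated dialog_task.
// dialog_task and seriex_ctx are abbreviated stand-ins, not the real
// definitions from llama3_chat_server.cpp.
#include <memory>
#include <string>

struct dialog_task {
    std::string uuid;
    std::string task_id;
    bool is_valid = false;
};

struct seriex_ctx {
    std::unique_ptr<dialog_task> d_task;  // freed automatically with the ctx
};

// Hypothetical mirror of serve_process's control flow: parse, validate, hand off.
void serve_process_sketch(seriex_ctx& ctx, bool request_ok) {
    auto d_task = std::make_unique<dialog_task>();
    d_task->is_valid = request_ok;  // parse_request(req, d_task.get()) in the real code
    if (!d_task->is_valid) {
        return;  // early return no longer needs a manual delete
    }
    ctx.d_task = std::move(d_task);  // the series context owns the task from here
}

With this arrangement the manual delete on the invalid-request path (and wherever the ctx is eventually torn down) becomes unnecessary, while call sites that dereference the pointer keep working via ctx.d_task.get().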