
Commit

update server
MaybeShewill-CV committed Dec 25, 2024
1 parent 0902577 commit db3e399
Showing 1 changed file with 29 additions and 29 deletions.
58 changes: 29 additions & 29 deletions src/server/llm/llama/llama3_chat_server.cpp
@@ -37,6 +37,7 @@ using jinq::common::FilePathUtil;
namespace llm {

using models::llm::Dialog;
+ using models::llm::ChatMessage;
using LLMPtr = models::llm::llama::Llama3<std::vector<llama_token>&, std::string>;

namespace llama {
@@ -109,7 +110,7 @@ class Llama3ChatServer::Impl {
int max_connection_nums = 200;
int peer_resp_timeout = 15 * 1000;
int compute_threads = -1;
- int handler_threads = 50;
+ int handler_threads = 25;
size_t request_size_limit = -1;

private:
@@ -144,7 +145,7 @@ class Llama3ChatServer::Impl {
std::string task_finished_ts;
bool is_task_req_valid = false;
std::string gen_out;
- dialog_task d_task;
+ dialog_task* d_task = nullptr;
};
// dialog cache
std::unordered_map<std::string, Dialog> _m_user_history_dialogs;
@@ -156,7 +157,7 @@
* @param task
* @return
*/
- static StatusCode parse_request(const protocol::HttpRequest* req, dialog_task& task);
+ static StatusCode parse_request(const protocol::HttpRequest* req, dialog_task* task);

/***
*
@@ -266,9 +267,9 @@ void Llama3ChatServer::Impl::serve_process(WFHttpTask* task) {
// parse request body
auto* req = task->get_req();
auto* resp = task->get_resp();
- dialog_task d_task{};
+ auto* d_task = new dialog_task;
parse_request(req, d_task);
- if (!d_task.is_valid) {
+ if (!d_task->is_valid) {
task->get_resp()->append_output_body(fmt::format("invalid request data: {}", protocol::HttpUtil::decode_chunked_body(req)));
return;
}
@@ -306,39 +307,38 @@ void Llama3ChatServer::Impl::serve_process(WFHttpTask* task) {
* @param task
* @return
*/
- StatusCode Llama3ChatServer::Impl::parse_request(const protocol::HttpRequest* req, dialog_task& task) {
+ StatusCode Llama3ChatServer::Impl::parse_request(const protocol::HttpRequest* req, dialog_task* task) {
// set task uuid
protocol::HttpHeaderMap map(req);
if (!map.key_exists("cookie")) {
- task.uuid = server_internal_impl::generate_uuid();
+ task->uuid = server_internal_impl::generate_uuid();
} else {
- task.uuid = map.get("cookie");
+ task->uuid = map.get("cookie");
}

std::string req_body = protocol::HttpUtil::decode_chunked_body(req);
rapidjson::Document doc;
doc.Parse(req_body.c_str());
if (!doc.IsObject()) {
- task.is_valid = false;
+ task->is_valid = false;
LOG(ERROR) << fmt::format("parse request body failed, invalid json str: {}", req_body);
return StatusCode::SERVER_RUN_FAILED;
}

if (doc.HasMember("task_id")) {
- task.task_id = doc["task_id"].GetString();
+ task->task_id = doc["task_id"].GetString();
}

if (!doc.HasMember("data")) {
- task.is_valid = false;
+ task->is_valid = false;
LOG(ERROR) << (fmt::format("invalid json str: {}, missing \"data\" field", req_body));
return StatusCode::SERVER_RUN_FAILED;
}
auto messages = doc["data"].GetArray();
for (auto& msg : messages) {
- auto role = msg["role"].GetString();
- auto content = msg["content"].GetString();
- llama_chat_message chat_msg = {role, content};
- task.current_dialog.push_back(chat_msg);
+ std::string role = msg["role"].GetString();
+ std::string content = msg["content"].GetString();
+ task->current_dialog.push_back({role, content});
}

return StatusCode::OK;
@@ -352,18 +352,18 @@ StatusCode Llama3ChatServer::Impl::parse_request(const protocol::HttpRequest* re
void Llama3ChatServer::Impl::complete_chat(seriex_ctx* ctx) {
// fetch current dialog
auto task = ctx->d_task;
- Dialog dialog = task.current_dialog;
+ Dialog dialog = task->current_dialog;

// generate response
- auto status = _m_generator->chat_completion(task.current_dialog, ctx->gen_out);
+ auto status = _m_generator->chat_completion(task->current_dialog, ctx->gen_out);

// cache history dialog
- llama_chat_message msg = {"assistant", ctx->gen_out.c_str()};
+ ChatMessage msg = {"assistant", ctx->gen_out};
dialog.push_back(msg);
- if (_m_user_history_dialogs.find(task.uuid) != _m_user_history_dialogs.end()) {
-     _m_user_history_dialogs[task.uuid] += dialog;
+ if (_m_user_history_dialogs.find(task->uuid) != _m_user_history_dialogs.end()) {
+     _m_user_history_dialogs[task->uuid] += dialog;
} else {
- _m_user_history_dialogs.insert(std::make_pair(task.uuid, dialog));
+ _m_user_history_dialogs.insert(std::make_pair(task->uuid, dialog));
}

// check if context exceeded occurred
@@ -430,7 +430,7 @@ void Llama3ChatServer::Impl::complete_chat_cb(const WFGoTask* task) {

auto response_body = buffer.GetString();
ctx->response->append_output_body(response_body);
- ctx->response->add_header_pair("Set-Cookie", ctx->d_task.uuid);
+ ctx->response->add_header_pair("Set-Cookie", ctx->d_task->uuid);

// update task count
_m_finished_jobs++;
@@ -449,7 +449,7 @@ StatusCode Llama3ChatServer::Impl::regenerate_with_cache_dialogs(
auto task = ctx->d_task;
// prepare summary dialog
Dialog summary_dialogs;
- auto history_dialogs = _m_user_history_dialogs[task.uuid];
+ auto history_dialogs = _m_user_history_dialogs[task->uuid];
std::string fmt_string;
auto status = _m_generator->apply_chat_template(history_dialogs, false, fmt_string);
if (status != StatusCode::OK) {
@@ -482,7 +482,7 @@ StatusCode Llama3ChatServer::Impl::regenerate_with_cache_dialogs(
summary_dialogs.push_back({"system", "You are an assistant skilled at generating summaries."});
summary_dialogs.push_back(
{"user", fmt::format("Please summarize the multi-turn conversation above "
"in content not exceeding {} tokens.", summary_token_nums).c_str()}
"in content not exceeding {} tokens.", summary_token_nums)}
);

// check summary dialog token nums
@@ -520,17 +520,17 @@ StatusCode Llama3ChatServer::Impl::regenerate_with_cache_dialogs(
for (auto i = msg_idx; i < history_dialogs.size(); ++i) {
updated_dialog.push_back(history_dialogs[i]);
}
- _m_user_history_dialogs[task.uuid].clean_cache();
- _m_user_history_dialogs[task.uuid] = updated_dialog;
+ _m_user_history_dialogs[task->uuid].clean_cache();
+ _m_user_history_dialogs[task->uuid] = updated_dialog;

// regenerate response content
_m_generator->clear_kv_cache_cell();
- Dialog cur_dialog = updated_dialog + task.current_dialog;
+ Dialog cur_dialog = updated_dialog + task->current_dialog;
status = _m_generator->chat_completion(cur_dialog, ctx->gen_out);

// cache dialog
- _m_user_history_dialogs[task.uuid] += task.current_dialog;
- _m_user_history_dialogs[task.uuid] += Dialog("assistant", ctx->gen_out);
+ _m_user_history_dialogs[task->uuid] += task->current_dialog;
+ _m_user_history_dialogs[task->uuid] += Dialog("assistant", ctx->gen_out);

return status;
}
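To illustrate the parse_request() hunk above, here is a minimal sketch of the JSON body the updated parser expects. The field names "task_id", "data", "role" and "content" come from the diff itself; the example values and the standalone program around them are illustrative assumptions, not part of this commit.

#include <string>
#include <rapidjson/document.h>

int main() {
    // Hypothetical request body; its shape mirrors what parse_request() reads.
    const std::string req_body = R"({
        "task_id": "demo-task-001",
        "data": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user",   "content": "Hello, who are you?"}
        ]
    })";

    rapidjson::Document doc;
    doc.Parse(req_body.c_str());

    // Same validation order as the server: the body must be a JSON object
    // and must carry a "data" array of {role, content} messages.
    if (!doc.IsObject() || !doc.HasMember("data")) {
        return 1;
    }
    for (auto& msg : doc["data"].GetArray()) {
        std::string role = msg["role"].GetString();       // e.g. "user"
        std::string content = msg["content"].GetString(); // the message text
        (void)role;
        (void)content;
    }
    return 0;
}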
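As a usage note on the cookie handling in this file (serve_process()/parse_request() read a "cookie" header to key the dialog cache, and complete_chat_cb() returns the uuid via "Set-Cookie"), a rough client-side sketch with the Workflow HTTP API the server is built on. The URL, port, route and uuid value are assumptions, not taken from this commit.

#include <string>
#include <workflow/WFTaskFactory.h>
#include <workflow/WFFacilities.h>
#include <workflow/HttpUtil.h>

int main() {
    // Hypothetical endpoint; the real host and route come from the server config.
    const std::string url = "http://127.0.0.1:8080/llama3/chat";
    const std::string body = R"({"data": [{"role": "user", "content": "hi"}]})";
    // uuid received in "Set-Cookie" from a previous turn; echoing it back lets
    // parse_request() find the cached history for this conversation.
    const std::string session_uuid = "previously-received-uuid";

    WFFacilities::WaitGroup wait_group(1);
    auto* task = WFTaskFactory::create_http_task(url, 0 /*redirect_max*/, 1 /*retry_max*/,
        [&wait_group](WFHttpTask* t) {
            if (t->get_state() == WFT_STATE_SUCCESS) {
                // The response body carries the generated answer; its "Set-Cookie"
                // header carries the uuid to send back on the next turn.
                std::string resp_body = protocol::HttpUtil::decode_chunked_body(t->get_resp());
                (void)resp_body;
            }
            wait_group.done();
        });
    task->get_req()->set_method("POST");
    task->get_req()->add_header_pair("cookie", session_uuid.c_str());
    task->get_req()->append_output_body(body.c_str(), body.size());
    task->start();
    wait_group.wait();
    return 0;
}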
