From dd4f95c3c94ff59c79de6d23d3093c65cce704d4 Mon Sep 17 00:00:00 2001
From: luoyao
Date: Fri, 20 Dec 2024 15:44:21 +0800
Subject: [PATCH] update llama3 chat server

---
 src/server/llm/llama/llama3_chat_server.cpp | 71 +++++++++++++--------
 1 file changed, 45 insertions(+), 26 deletions(-)

diff --git a/src/server/llm/llama/llama3_chat_server.cpp b/src/server/llm/llama/llama3_chat_server.cpp
index 2975134..a3bcbfd 100644
--- a/src/server/llm/llama/llama3_chat_server.cpp
+++ b/src/server/llm/llama/llama3_chat_server.cpp
@@ -25,7 +25,8 @@
 #include "common/status_code.h"
 #include "common/file_path_util.h"
 
-#include "models/llm/llama/llama3_generator.h"
+#include "models/llm/llm_datatype.hpp"
+#include "models/llm/llama/llama3.h"
 
 namespace jinq {
 namespace server {
@@ -35,9 +36,8 @@
 using jinq::common::FilePathUtil;
 
 namespace llm {
 
-using models::llm::chat_template::Dialog;
-using models::llm::chat_template::ChatMessage;
-using models::llm::llama::Llama3Generator;
+using models::llm::Dialog;
+using LLMPtr = models::llm::llama::Llama3<std::vector<llama_token>, std::string>;
 
 namespace llama {
@@ -124,7 +124,7 @@
     // server uri
     std::string _m_server_uri;
     // llama3 generator
-    Llama3Generator _m_generator;
+    std::unique_ptr<LLMPtr> _m_generator;
 
  private:
     // dialog task
@@ -202,9 +202,10 @@ StatusCode Llama3ChatServer::Impl::init(const decltype(toml::parse("")) &config)
         return StatusCode::SERVER_INIT_FAILED;
     }
     auto model_cfg = toml::parse(model_cfg_path);
-    auto status = _m_generator.init(model_cfg);
+    _m_generator = std::make_unique<LLMPtr>();
+    auto status = _m_generator->init(model_cfg);
     if (status != StatusCode::OK) {
-        LOG(ERROR) << (fmt::format("init llama3 generator failed, status code: {}", std::to_string(status)));
+        LOG(ERROR) << fmt::format("init llama3 model failed, status code: {}", std::to_string(status));
         return StatusCode::SERVER_INIT_FAILED;
     }
 
@@ -246,15 +247,15 @@ void Llama3ChatServer::Impl::serve_process(WFHttpTask* task) {
     }
     // check model stat
     else if (strcmp(task->get_req()->get_request_uri(), "/check_model_stat") == 0) {
-        auto model_stat = _m_generator.get_model_stat();
+        auto model_stat = _m_generator->get_model_stat();
         task->get_resp()->append_output_body(fmt::format(
             "n_ctx: {}\n kv cache used: {}", model_stat.n_ctx_size, model_stat.kv_cache_cell_nums));
         return;
     }
     // clear kv cache
     else if (strcmp(task->get_req()->get_request_uri(), "/clear_kv_cache") == 0) {
-        _m_generator.clear_kv_cache_cell();
-        auto model_stat = _m_generator.get_model_stat();
+        _m_generator->clear_kv_cache_cell();
+        auto model_stat = _m_generator->get_model_stat();
         task->get_resp()->append_output_body(fmt::format(
             "kv cache cleared.\n n_ctx: {}\n kv cache used: {}", model_stat.n_ctx_size, model_stat.kv_cache_cell_nums));
         return;
@@ -318,7 +319,7 @@ StatusCode Llama3ChatServer::Impl::parse_request(const protocol::HttpRequest* re
     doc.Parse(req_body.c_str());
     if (!doc.IsObject()) {
         task.is_valid = false;
-        LOG(ERROR) << (fmt::format("parse request body failed, invalid json str: {}", req_body));
+        LOG(ERROR) << fmt::format("parse request body failed, invalid json str: {}", req_body);
         return StatusCode::SERVER_RUN_FAILED;
     }
 
@@ -335,7 +336,7 @@
     for (auto& msg : messages) {
         auto role = msg["role"].GetString();
         auto content = msg["content"].GetString();
-        ChatMessage chat_msg = {role, content};
+        llama_chat_message chat_msg = {role, content};
         task.current_dialog.push_back(chat_msg);
     }
 
@@ -353,10 +354,10 @@ void Llama3ChatServer::Impl::complete_chat(seriex_ctx* ctx) {
     Dialog dialog = task.current_dialog;
 
     // generate response
-    auto status = _m_generator.chat_completion(task.current_dialog, ctx->gen_out);
+    auto status = _m_generator->chat_completion(task.current_dialog, ctx->gen_out);
 
     // cache history dialog
-    ChatMessage msg = {"assistant", ctx->gen_out};
+    llama_chat_message msg = {"assistant", ctx->gen_out.c_str()};
     dialog.push_back(msg);
     if (_m_user_history_dialogs.find(task.uuid) != _m_user_history_dialogs.end()) {
         _m_user_history_dialogs[task.uuid] += dialog;
@@ -448,15 +449,27 @@ StatusCode Llama3ChatServer::Impl::regenerate_with_cache_dialogs(
     // prepare summary dialog
     Dialog summary_dialogs;
     auto history_dialogs = _m_user_history_dialogs[task.uuid];
-    auto history_dialog_tokens = _m_generator.count_dialog_token_nums(history_dialogs);
+    std::string fmt_string;
+    auto status = _m_generator->apply_chat_template(history_dialogs, false, fmt_string);
+    if (status != StatusCode::OK) {
+        return status;
+    }
+    std::vector<llama_token> fmt_tokens;
+    status = _m_generator->tokenize(fmt_string, fmt_tokens, true);
+    if (status != StatusCode::OK) {
+        return status;
+    }
+    auto history_dialog_tokens = fmt_tokens.size();
     auto drop_threshold = static_cast<int32_t>(dropped_token_ratio * static_cast<float>(history_dialog_tokens));
-    int32_t dropped_token_nums = 0;
+    size_t dropped_token_nums = 0;
     int msg_idx =0;
     for (; msg_idx < history_dialogs.size(); ++msg_idx) {
         auto role = history_dialogs[msg_idx].role;
         auto content = history_dialogs[msg_idx].content;
         Dialog tmp_dia(role, content);
-        dropped_token_nums += _m_generator.count_dialog_token_nums(tmp_dia);
+        _m_generator->apply_chat_template(tmp_dia, false, fmt_string);
+        _m_generator->tokenize(fmt_string, fmt_tokens, true);
+        dropped_token_nums += fmt_tokens.size();
         summary_dialogs += tmp_dia;
         if (dropped_token_nums >= drop_threshold) {
             msg_idx++;
@@ -468,22 +481,26 @@
     summary_dialogs.push_back({"system", "You are an assistant skilled at generating summaries."});
     summary_dialogs.push_back(
         {"user", fmt::format("Please summarize the multi-turn conversation above "
-                             "in content not exceeding {} tokens.", summary_token_nums)}
+                             "in content not exceeding {} tokens.", summary_token_nums).c_str()}
     );
 
     // check summary dialog token nums
-    auto summary_tokens = _m_generator.count_dialog_token_nums(summary_dialogs);
-    auto n_ctx = _m_generator.get_model_stat().n_ctx_size;
+    _m_generator->apply_chat_template(summary_dialogs, false, fmt_string);
+    _m_generator->tokenize(fmt_string, fmt_tokens, true);
+    auto summary_tokens = static_cast<int32_t>(fmt_tokens.size());
+    auto n_ctx = _m_generator->get_model_stat().n_ctx_size;
     while (summary_tokens > 0.75 * n_ctx) {
         summary_dialogs.messages.erase(summary_dialogs.messages.begin());
-        summary_tokens = _m_generator.count_dialog_token_nums(summary_dialogs);
+        _m_generator->apply_chat_template(summary_dialogs, false, fmt_string);
+        _m_generator->tokenize(fmt_string, fmt_tokens, true);
+        summary_tokens = static_cast<int32_t>(fmt_tokens.size());
     }
     LOG(INFO) << "n_tokens: " << summary_tokens << " used before summary";
 
     // generate summary msg
-    _m_generator.clear_kv_cache_cell();
+    _m_generator->clear_kv_cache_cell();
     std::string summary_msg;
-    auto status = _m_generator.chat_completion(summary_dialogs, summary_msg);
+    status = _m_generator->chat_completion(summary_dialogs, summary_msg);
     if (status != StatusCode::OK) {
         return status;
     }
@@ -496,7 +513,9 @@ StatusCode Llama3ChatServer::Impl::regenerate_with_cache_dialogs(
         "conversation. Summary content is {}.Please continue assisting the customer based on it.",
         summary_dialogs.size(), summary_msg)
     );
-    LOG(INFO) << "n_tokens: " << _m_generator.count_dialog_token_nums(updated_dialog) << " used after summary";
+    _m_generator->apply_chat_template(updated_dialog, false, fmt_string);
+    _m_generator->tokenize(fmt_string, fmt_tokens, true);
+    LOG(INFO) << "n_tokens: " << fmt_tokens.size() << " used after summary";
     for (auto i = msg_idx; i < history_dialogs.size(); ++i) {
         updated_dialog.push_back(history_dialogs[i]);
     }
@@ -504,9 +523,9 @@ StatusCode Llama3ChatServer::Impl::regenerate_with_cache_dialogs(
     _m_user_history_dialogs[task.uuid] = updated_dialog;
 
     // regenerate response content
-    _m_generator.clear_kv_cache_cell();
+    _m_generator->clear_kv_cache_cell();
     Dialog cur_dialog = updated_dialog + task.current_dialog;
-    status = _m_generator.chat_completion(cur_dialog, ctx->gen_out);
+    status = _m_generator->chat_completion(cur_dialog, ctx->gen_out);
 
     // cache dialog
     _m_user_history_dialogs[task.uuid] += task.current_dialog;