
update llama3 chat server
MaybeShewill-CV committed Dec 20, 2024
1 parent 9f5135c commit dd4f95c
Showing 1 changed file with 45 additions and 26 deletions.
71 changes: 45 additions & 26 deletions src/server/llm/llama/llama3_chat_server.cpp
@@ -25,7 +25,8 @@

#include "common/status_code.h"
#include "common/file_path_util.h"
#include "models/llm/llama/llama3_generator.h"
#include "models/llm/llm_datatype.hpp"
#include "models/llm/llama/llama3.h"

namespace jinq {
namespace server {
@@ -35,9 +36,8 @@ using jinq::common::FilePathUtil;

namespace llm {

-using models::llm::chat_template::Dialog;
-using models::llm::chat_template::ChatMessage;
-using models::llm::llama::Llama3Generator;
+using models::llm::Dialog;
+using LLMPtr = models::llm::llama::Llama3<std::vector<llama_token>, std::string>;

namespace llama {

@@ -124,7 +124,7 @@ class Llama3ChatServer::Impl {
// server uri
std::string _m_server_uri;
// llama3 generator
-Llama3Generator _m_generator;
+std::unique_ptr<LLMPtr > _m_generator;

private:
// dialog task
@@ -202,9 +202,10 @@ StatusCode Llama3ChatServer::Impl::init(const decltype(toml::parse("")) &config)
return StatusCode::SERVER_INIT_FAILED;
}
auto model_cfg = toml::parse(model_cfg_path);
-auto status = _m_generator.init(model_cfg);
+_m_generator = std::make_unique<LLMPtr>();
+auto status = _m_generator->init(model_cfg);
if (status != StatusCode::OK) {
-LOG(ERROR) << (fmt::format("init llama3 generator failed, status code: {}", std::to_string(status)));
+LOG(ERROR) << fmt::format("init llama3 model failed, status code: {}", std::to_string(status));
return StatusCode::SERVER_INIT_FAILED;
}

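Note (not part of the commit): the member swaps from a by-value Llama3Generator to a std::unique_ptr holding the templated Llama3<std::vector<llama_token>, std::string>, so every call site changes from '.' to '->' and the model is constructed with std::make_unique before init. Below is a minimal usage sketch against the interface visible in this diff; run_once and the dialog contents are illustrative, and the config type is left generic because the toml header path is not shown here.

// Sketch only: drives the pointer-held Llama3 model the same way the updated server does.
// run_once and its arguments are illustrative, not part of the repository.
#include <memory>
#include <string>
#include <vector>

#include "common/status_code.h"
#include "models/llm/llm_datatype.hpp"   // Dialog (per the include added in this commit)
#include "models/llm/llama/llama3.h"     // Llama3<INPUT, OUTPUT>; assumed to pull in llama_token

using jinq::common::StatusCode;
using jinq::models::llm::Dialog;
using LLM = jinq::models::llm::llama::Llama3<std::vector<llama_token>, std::string>;

template <typename TomlConfig>                    // keeps the sketch independent of the toml header
StatusCode run_once(const TomlConfig &model_cfg, std::string &out) {
    auto model = std::make_unique<LLM>();         // heap-owned, replaces the old by-value generator
    auto status = model->init(model_cfg);         // same init contract the server relies on
    if (status != StatusCode::OK) {
        return status;
    }
    Dialog dialog;
    dialog.push_back({"system", "You are a helpful assistant."});
    dialog.push_back({"user", "Hello!"});
    return model->chat_completion(dialog, out);   // same entry point complete_chat uses below
}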
@@ -246,15 +247,15 @@ void Llama3ChatServer::Impl::serve_process(WFHttpTask* task) {
}
// check model stat
else if (strcmp(task->get_req()->get_request_uri(), "/check_model_stat") == 0) {
-auto model_stat = _m_generator.get_model_stat();
+auto model_stat = _m_generator->get_model_stat();
task->get_resp()->append_output_body(fmt::format(
"<html>n_ctx: {}\n kv cache used: {}</html>", model_stat.n_ctx_size, model_stat.kv_cache_cell_nums));
return;
}
// clear kv cache
else if (strcmp(task->get_req()->get_request_uri(), "/clear_kv_cache") == 0) {
-_m_generator.clear_kv_cache_cell();
-auto model_stat = _m_generator.get_model_stat();
+_m_generator->clear_kv_cache_cell();
+auto model_stat = _m_generator->get_model_stat();
task->get_resp()->append_output_body(fmt::format(
"<html>kv cache cleared.\n n_ctx: {}\n kv cache used: {}</html>", model_stat.n_ctx_size, model_stat.kv_cache_cell_nums));
return;
@@ -318,7 +319,7 @@ StatusCode Llama3ChatServer::Impl::parse_request(const protocol::HttpRequest* re
doc.Parse(req_body.c_str());
if (!doc.IsObject()) {
task.is_valid = false;
-LOG(ERROR) << (fmt::format("parse request body failed, invalid json str: {}", req_body));
+LOG(ERROR) << fmt::format("parse request body failed, invalid json str: {}", req_body);
return StatusCode::SERVER_RUN_FAILED;
}

@@ -335,7 +336,7 @@
for (auto& msg : messages) {
auto role = msg["role"].GetString();
auto content = msg["content"].GetString();
-ChatMessage chat_msg = {role, content};
+llama_chat_message chat_msg = {role, content};
task.current_dialog.push_back(chat_msg);
}

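Note (not part of the commit): with chat_template::ChatMessage removed, dialog entries are now llama.cpp's llama_chat_message structs, whose role and content fields are raw const char* pointers. A hedged sketch of the same JSON-to-Dialog conversion in isolation; it assumes the parsed rapidjson document outlives the dialog, as it does inside parse_request.

// Sketch only: builds a Dialog from a parsed request body, mirroring parse_request above.
// llama_chat_message stores raw pointers, so the rapidjson document must outlive the dialog.
#include <rapidjson/document.h>

#include <llama.h>                       // llama_chat_message; include path may differ per build
#include "models/llm/llm_datatype.hpp"   // Dialog

using jinq::models::llm::Dialog;

Dialog dialog_from_request(const rapidjson::Document &doc) {
    Dialog dialog;
    const auto &messages = doc["messages"];                  // array of {"role": ..., "content": ...}
    for (const auto &msg : messages.GetArray()) {
        llama_chat_message chat_msg = {msg["role"].GetString(), msg["content"].GetString()};
        dialog.push_back(chat_msg);
    }
    return dialog;
}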
@@ -353,10 +354,10 @@ void Llama3ChatServer::Impl::complete_chat(seriex_ctx* ctx) {
Dialog dialog = task.current_dialog;

// generate response
-auto status = _m_generator.chat_completion(task.current_dialog, ctx->gen_out);
+auto status = _m_generator->chat_completion(task.current_dialog, ctx->gen_out);

// cache history dialog
-ChatMessage msg = {"assistant", ctx->gen_out};
+llama_chat_message msg = {"assistant", ctx->gen_out.c_str()};
dialog.push_back(msg);
if (_m_user_history_dialogs.find(task.uuid) != _m_user_history_dialogs.end()) {
_m_user_history_dialogs[task.uuid] += dialog;
@@ -448,15 +449,27 @@ StatusCode Llama3ChatServer::Impl::regenerate_with_cache_dialogs(
// prepare summary dialog
Dialog summary_dialogs;
auto history_dialogs = _m_user_history_dialogs[task.uuid];
-auto history_dialog_tokens = _m_generator.count_dialog_token_nums(history_dialogs);
+std::string fmt_string;
+auto status = _m_generator->apply_chat_template(history_dialogs, false, fmt_string);
+if (status != StatusCode::OK) {
+return status;
+}
+std::vector<llama_token > fmt_tokens;
+status = _m_generator->tokenize(fmt_string, fmt_tokens, true);
+if (status != StatusCode::OK) {
+return status;
+}
+auto history_dialog_tokens = fmt_tokens.size();
auto drop_threshold = static_cast<int32_t >(dropped_token_ratio * static_cast<float>(history_dialog_tokens));
-int32_t dropped_token_nums = 0;
+size_t dropped_token_nums = 0;
int msg_idx =0;
for (; msg_idx < history_dialogs.size(); ++msg_idx) {
auto role = history_dialogs[msg_idx].role;
auto content = history_dialogs[msg_idx].content;
Dialog tmp_dia(role, content);
-dropped_token_nums += _m_generator.count_dialog_token_nums(tmp_dia);
+_m_generator->apply_chat_template(tmp_dia, false, fmt_string);
+_m_generator->tokenize(fmt_string, fmt_tokens, true);
+dropped_token_nums += fmt_tokens.size();
summary_dialogs += tmp_dia;
if (dropped_token_nums >= drop_threshold) {
msg_idx++;
@@ -468,22 +481,26 @@
summary_dialogs.push_back({"system", "You are an assistant skilled at generating summaries."});
summary_dialogs.push_back(
{"user", fmt::format("Please summarize the multi-turn conversation above "
-"in content not exceeding {} tokens.", summary_token_nums)}
+"in content not exceeding {} tokens.", summary_token_nums).c_str()}
);

// check summary dialog token nums
-auto summary_tokens = _m_generator.count_dialog_token_nums(summary_dialogs);
-auto n_ctx = _m_generator.get_model_stat().n_ctx_size;
+_m_generator->apply_chat_template(summary_dialogs, false, fmt_string);
+_m_generator->tokenize(fmt_string, fmt_tokens, true);
+auto summary_tokens = static_cast<double>(fmt_tokens.size());
+auto n_ctx = _m_generator->get_model_stat().n_ctx_size;
while (summary_tokens > 0.75 * n_ctx) {
summary_dialogs.messages.erase(summary_dialogs.messages.begin());
-summary_tokens = _m_generator.count_dialog_token_nums(summary_dialogs);
+_m_generator->apply_chat_template(summary_dialogs, false, fmt_string);
+_m_generator->tokenize(fmt_string, fmt_tokens, true);
+summary_tokens = static_cast<double>(fmt_tokens.size());
}
LOG(INFO) << "n_tokens: " << summary_tokens << " used before summary";

// generate summary msg
-_m_generator.clear_kv_cache_cell();
+_m_generator->clear_kv_cache_cell();
std::string summary_msg;
-auto status = _m_generator.chat_completion(summary_dialogs, summary_msg);
+status = _m_generator->chat_completion(summary_dialogs, summary_msg);
if (status != StatusCode::OK) {
return status;
}
@@ -496,17 +513,19 @@
"conversation. Summary content is {}.Please continue assisting the customer based on it.",
summary_dialogs.size(), summary_msg)
);
-LOG(INFO) << "n_tokens: " << _m_generator.count_dialog_token_nums(updated_dialog) << " used after summary";
+_m_generator->apply_chat_template(updated_dialog, false, fmt_string);
+_m_generator->tokenize(fmt_string, fmt_tokens, true);
+LOG(INFO) << "n_tokens: " << fmt_tokens.size() << " used after summary";
for (auto i = msg_idx; i < history_dialogs.size(); ++i) {
updated_dialog.push_back(history_dialogs[i]);
}
_m_user_history_dialogs[task.uuid].clean_cache();
_m_user_history_dialogs[task.uuid] = updated_dialog;

// regenerate response content
-_m_generator.clear_kv_cache_cell();
+_m_generator->clear_kv_cache_cell();
Dialog cur_dialog = updated_dialog + task.current_dialog;
-status = _m_generator.chat_completion(cur_dialog, ctx->gen_out);
+status = _m_generator->chat_completion(cur_dialog, ctx->gen_out);

// cache dialog
_m_user_history_dialogs[task.uuid] += task.current_dialog;
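Note (not part of the commit): the removed count_dialog_token_nums helper is replaced throughout by a two-step idiom, render the dialog with apply_chat_template, then tokenize the rendered string, and the token count is simply the size of the resulting token vector. Below is a hedged helper capturing that idiom with only the calls visible in this diff; count_dialog_tokens is an illustrative name, not a function in the repository.

// Sketch only: the token-counting idiom repeated in regenerate_with_cache_dialogs above,
// factored into one helper. count_dialog_tokens is illustrative, not part of the repository.
#include <string>
#include <vector>

#include "common/status_code.h"
#include "models/llm/llm_datatype.hpp"   // Dialog
#include "models/llm/llama/llama3.h"     // Llama3<INPUT, OUTPUT>; assumed to pull in llama_token

using jinq::common::StatusCode;
using jinq::models::llm::Dialog;
using LLM = jinq::models::llm::llama::Llama3<std::vector<llama_token>, std::string>;

StatusCode count_dialog_tokens(LLM &model, Dialog &dialog, size_t &token_nums) {
    // Render the dialog through the model's chat template; 'false' mirrors the flag used in
    // the diff (assumed to mean "do not append the assistant generation prompt").
    std::string fmt_string;
    auto status = model.apply_chat_template(dialog, false, fmt_string);
    if (status != StatusCode::OK) {
        return status;
    }
    // Tokenize the rendered prompt; 'true' mirrors the flag used in the diff.
    std::vector<llama_token> fmt_tokens;
    status = model.tokenize(fmt_string, fmt_tokens, true);
    if (status != StatusCode::OK) {
        return status;
    }
    token_nums = fmt_tokens.size();   // compared against n_ctx and drop thresholds by the server
    return StatusCode::OK;
}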
