From 696b2e38925af2b49f21f18163bd7204bafea008 Mon Sep 17 00:00:00 2001
From: luoyao
Date: Thu, 28 Nov 2024 21:06:04 +0800
Subject: [PATCH] add llama3 chat template and generator scaffolding

---
 .../model_benchmark/llm/llama3_benchmark.cpp       |  32 ++---
 src/models/CMakeLists.txt                          |   2 +-
 src/models/llm/base_generator.h                    |  11 ++
 .../llm/chat_template/base_chat_template.h         |  70 ++++++++++++
 .../llm/chat_template/llama3_chat_template.h       |  66 +++++++++++
 .../chat_template/llama3_chat_template.inl         | 106 ++++++++++++++++++
 src/models/llm/llama/llama3_generator.cpp          |   8 ++
 src/models/llm/llama/llama3_generator.h            |  35 +++
 8 files changed, 315 insertions(+), 15 deletions(-)
 create mode 100644 src/models/llm/base_generator.h
 create mode 100644 src/models/llm/chat_template/base_chat_template.h
 create mode 100644 src/models/llm/chat_template/llama3_chat_template.h
 create mode 100644 src/models/llm/chat_template/llama3_chat_template.inl
 create mode 100644 src/models/llm/llama/llama3_generator.cpp
 create mode 100644 src/models/llm/llama/llama3_generator.h

diff --git a/src/apps/model_benchmark/llm/llama3_benchmark.cpp b/src/apps/model_benchmark/llm/llama3_benchmark.cpp
index d8fd488..13470f4 100644
--- a/src/apps/model_benchmark/llm/llama3_benchmark.cpp
+++ b/src/apps/model_benchmark/llm/llama3_benchmark.cpp
@@ -12,12 +12,12 @@
 
 #include "common/file_path_util.h"
 #include "common/time_stamp.h"
-#include "models/model_io_define.h"
-#include "models/llm/llama/llama3.h"
+#include "models/llm/llama/llama3_generator.h"
 
 using jinq::common::FilePathUtil;
-using jinq::common::Timestamp;
-using jinq::models::llm::llama::Llama3;
+using jinq::models::llm::chat_template::Dialog;
+using jinq::models::llm::chat_template::ChatMessage;
+using jinq::models::llm::llama::Llama3Generator;
 
 int main(int argc, char** argv) {
 
@@ -35,19 +35,23 @@
     }
 
     // construct llama3 model
-    Llama3 model;
+    Llama3Generator generator;
     auto cfg = toml::parse(cfg_file_path);
-    model.init(cfg);
-    if (!model.is_successfully_initialized()) {
-        LOG(INFO) << "llama3 model init failed";
+    generator.init(cfg);
+    if (!generator.is_successfully_initialized()) {
+        LOG(INFO) << "llama3 generator init failed";
         return -1;
     }
 
-    std::string input = "\n<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nWho are you<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\nI am an assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nWho creates you?<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n";
-    LOG(INFO) << "input prompt text: " << input;
-    std::string out;
-    model.run(input, out);
-    LOG(INFO) << "generated output: " << out;
+    Dialog dialog = {
+        {"system", "You're a smart AI assistant from Mortred Company"},
+        {"user", "Who are you?"},
+        {"assistant", "I am an AI assistant"},
+        {"user", "Where are you from?"},
+    };
+    std::string gen_out;
+    auto status = generator.chat_completion(dialog, gen_out);
+    LOG(INFO) << "generated output: " << gen_out;
 
-    return 0;
+    return static_cast<int>(status);
 }
diff --git a/src/models/CMakeLists.txt b/src/models/CMakeLists.txt
index 7b29960..8c322ce 100644
--- a/src/models/CMakeLists.txt
+++ b/src/models/CMakeLists.txt
@@ -20,6 +20,6 @@ target_link_libraries(
         ${CUDA_LIBRARIES}
         ${WORKFLOW_LIBS}
         ${LLAMA_LIBRARIES}
-#        ${GGML_LIBRARIES}
+        ${GGML_LIBRARIES}
 )
 set_target_properties(models PROPERTIES LINKER_LANGUAGE CXX)
diff --git a/src/models/llm/base_generator.h b/src/models/llm/base_generator.h
new file mode 100644
index 0000000..7c7c7f3
--- /dev/null
+++ b/src/models/llm/base_generator.h
@@ -0,0 +1,11 @@
+/************************************************
+ * Copyright MaybeShewill-CV. All Rights Reserved.
+ * Author: MaybeShewill-CV
+ * File: base_generator.h
+ * Date: 24-11-28
+ ************************************************/
+
+#ifndef MORTRED_MODEL_SERVER_BASE_GENERATOR_H
+#define MORTRED_MODEL_SERVER_BASE_GENERATOR_H
+
+#endif // MORTRED_MODEL_SERVER_BASE_GENERATOR_H
diff --git a/src/models/llm/chat_template/base_chat_template.h b/src/models/llm/chat_template/base_chat_template.h
new file mode 100644
index 0000000..de7f36f
--- /dev/null
+++ b/src/models/llm/chat_template/base_chat_template.h
@@ -0,0 +1,70 @@
+/************************************************
+ * Copyright MaybeShewill-CV. All Rights Reserved.
+ * Author: MaybeShewill-CV
+ * File: base_chat_template.h
+ * Date: 24-11-26
+ ************************************************/
+
+#ifndef MORTRED_MODEL_SERVER_BASE_CHAT_TEMPLATE_H
+#define MORTRED_MODEL_SERVER_BASE_CHAT_TEMPLATE_H
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "common/status_code.h"
+
+namespace jinq {
+namespace models {
+namespace llm {
+namespace chat_template {
+
+struct ChatMessage {
+    std::string role;
+    std::string content;
+
+    ChatMessage(std::string r, std::string c) : role(std::move(r)), content(std::move(c)) {}
+};
+
+using Dialog = std::vector<ChatMessage>;
+
+class BaseChatTemplate {
+  public:
+    /***
+     * destructor
+     */
+    virtual ~BaseChatTemplate() = default;
+
+    /***
+     * constructor
+     */
+    BaseChatTemplate() = default;
+
+    /***
+     * copy constructor
+     * @param transformer
+     */
+    BaseChatTemplate(const BaseChatTemplate &transformer) = default;
+
+    /***
+     * assignment operator
+     * @param transformer
+     * @return
+     */
+    BaseChatTemplate &operator=(const BaseChatTemplate &transformer) = default;
+
+    /***
+     * format a dialog's messages into a single prompt string
+     * @param messages
+     * @param out_fmt_str
+     * @return
+     */
+    virtual jinq::common::StatusCode apply_chat_template(const std::vector<ChatMessage>& messages, std::string& out_fmt_str) = 0;
+};
+
+}
+}
+}
+}
+
+#endif // MORTRED_MODEL_SERVER_BASE_CHAT_TEMPLATE_H
diff --git a/src/models/llm/chat_template/llama3_chat_template.h b/src/models/llm/chat_template/llama3_chat_template.h
new file mode 100644
index 0000000..7bb4402
--- /dev/null
+++ b/src/models/llm/chat_template/llama3_chat_template.h
@@ -0,0 +1,66 @@
+/************************************************
+ * Copyright MaybeShewill-CV. All Rights Reserved.
+ * Author: MaybeShewill-CV
+ * File: llama3_chat_template.h
+ * Date: 24-11-26
+ ************************************************/
+
+#ifndef MORTRED_MODEL_SERVER_LLAMA3_CHAT_TEMPLATE_H
+#define MORTRED_MODEL_SERVER_LLAMA3_CHAT_TEMPLATE_H
+
+#include <memory>
+
+#include "models/llm/chat_template/base_chat_template.h"
+
+namespace jinq {
+namespace models {
+namespace llm {
+namespace chat_template {
+
+class Llama3ChatTemplate : public BaseChatTemplate {
+  public:
+    /***
+     * constructor
+     */
+    Llama3ChatTemplate();
+
+    /***
+     * destructor
+     */
+    ~Llama3ChatTemplate() override;
+
+    /***
+     * copy constructor (deleted)
+     * @param transformer
+     */
+    Llama3ChatTemplate(const Llama3ChatTemplate &transformer) = delete;
+
+    /***
+     * assignment operator (deleted)
+     * @param transformer
+     * @return
+     */
+    Llama3ChatTemplate &operator=(const Llama3ChatTemplate &transformer) = delete;
+
+    /***
+     * apply the llama3 chat template to a dialog's messages
+     * @param messages
+     * @param out_fmt_str
+     * @return
+     */
+    jinq::common::StatusCode apply_chat_template(const std::vector<ChatMessage>& messages, std::string& out_fmt_str) override;
+
+  private:
+    class Impl;
+    std::unique_ptr<Impl> _m_pimpl;
+};
+
+}
+}
+}
+}
+
+#include "llama3_chat_template.inl"
+
+#endif // MORTRED_MODEL_SERVER_LLAMA3_CHAT_TEMPLATE_H
diff --git a/src/models/llm/chat_template/llama3_chat_template.inl b/src/models/llm/chat_template/llama3_chat_template.inl
new file mode 100644
index 0000000..fb09e80
--- /dev/null
+++ b/src/models/llm/chat_template/llama3_chat_template.inl
@@ -0,0 +1,106 @@
+/************************************************
+ * Copyright MaybeShewill-CV. All Rights Reserved.
+ * Author: MaybeShewill-CV
+ * File: llama3_chat_template.inl
+ * Date: 24-11-26
+ ************************************************/
+
+#include "models/llm/chat_template/llama3_chat_template.h"
+
+#include "fmt/format.h"
+
+namespace jinq {
+namespace models {
+namespace llm {
+
+using jinq::common::StatusCode;
+
+namespace chat_template {
+
+class Llama3ChatTemplate::Impl {
+  public:
+    /***
+     * constructor
+     */
+    Impl() = default;
+
+    /***
+     * destructor
+     */
+    ~Impl() = default;
+
+    /***
+     * copy constructor (deleted)
+     * @param transformer
+     */
+    Impl(const Impl& transformer) = delete;
+
+    /***
+     * assignment operator (deleted)
+     * @param transformer
+     * @return
+     */
+    Impl& operator=(const Impl& transformer) = delete;
+
+    /***
+     * render every message as <|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>
+     * and append the assistant generation header
+     * @param messages
+     * @param out_fmt_str
+     * @return
+     */
+    StatusCode apply_chat_template(const std::vector<ChatMessage>& messages, std::string& out_fmt_str);
+
+  private:
+    std::string _m_header_fmt = "<|start_header_id|>{}<|end_header_id|>\n\n";
+    std::string _m_message_fmt = "<|start_header_id|>{}<|end_header_id|>\n\n{}<|eot_id|>";
+
+};
+
+/***
+ * note: defined inline since this .inl is included by the header (ODR)
+ * @param messages
+ * @param out_fmt_str
+ * @return
+ */
+inline StatusCode Llama3ChatTemplate::Impl::apply_chat_template(const std::vector<ChatMessage> &messages, std::string &out_fmt_str) {
+    if (messages.empty()) {
+        return StatusCode::TOKENIZE_FAILED;
+    }
+    std::string fmt_dialog;
+    for (auto& message : messages) {
+        fmt_dialog += fmt::format(_m_message_fmt, message.role, message.content);
+    }
+    fmt_dialog += fmt::format(_m_header_fmt, "assistant");
+    out_fmt_str = fmt_dialog;
+    return StatusCode::OK;
+}
+
+
+/************* Export Function Sets *************/
+
+/***
+ * constructor
+ */
+inline Llama3ChatTemplate::Llama3ChatTemplate() {
+    _m_pimpl = std::make_unique<Impl>();
+}
+
+/***
+ * destructor
+ */
+inline Llama3ChatTemplate::~Llama3ChatTemplate() = default;
+
+/***
+ * forward to the pimpl implementation
+ * @param messages
+ * @param out_fmt_str
+ * @return
+ */
+inline StatusCode Llama3ChatTemplate::apply_chat_template(const std::vector<ChatMessage> &messages, std::string &out_fmt_str) {
+    return _m_pimpl->apply_chat_template(messages, out_fmt_str);
+}
+
+}
+}
+}
+}
diff --git a/src/models/llm/llama/llama3_generator.cpp b/src/models/llm/llama/llama3_generator.cpp
new file mode 100644
index 0000000..a5a867e
--- /dev/null
+++ b/src/models/llm/llama/llama3_generator.cpp
@@ -0,0 +1,8 @@
+/************************************************
+ * Copyright MaybeShewill-CV. All Rights Reserved.
+ * Author: MaybeShewill-CV
+ * File: llama3_generator.cpp
+ * Date: 24-11-28
+ ************************************************/
+
+#include "llama3_generator.h"
diff --git a/src/models/llm/llama/llama3_generator.h b/src/models/llm/llama/llama3_generator.h
new file mode 100644
index 0000000..07b19f9
--- /dev/null
+++ b/src/models/llm/llama/llama3_generator.h
@@ -0,0 +1,35 @@
+/************************************************
+ * Copyright MaybeShewill-CV. All Rights Reserved.
+ * Author: MaybeShewill-CV
+ * File: llama3_generator.h
+ * Date: 24-11-28
+ ************************************************/
+
+#ifndef MORTRED_MODEL_SERVER_LLAMA3_GENERATOR_H
+#define MORTRED_MODEL_SERVER_LLAMA3_GENERATOR_H
+
+#include "toml/toml.h"
+
+#include "models/llm/chat_template/llama3_chat_template.h"
+
+namespace jinq {
+namespace models {
+namespace llm {
+namespace llama {
+
+// interface skeleton only: the method signatures below are inferred from the
+// benchmark usage above (the "toml/toml.h" include path is an assumption);
+// the implementation is still pending in llama3_generator.cpp
+class Llama3Generator {
+  public:
+    jinq::common::StatusCode init(const toml::value &config);
+    bool is_successfully_initialized() const;
+    jinq::common::StatusCode chat_completion(chat_template::Dialog &dialog, std::string &gen_out);
+};
+
+}
+}
+}
+}
+
+#endif // MORTRED_MODEL_SERVER_LLAMA3_GENERATOR_H
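
For reference, a minimal usage sketch of the new Llama3ChatTemplate, not part
of the patch: `tpl` and the exit-code handling are illustrative, and it
assumes the repo's include paths plus an fmt version that accepts the runtime
format strings used in the .inl:

    #include <iostream>
    #include <string>

    #include "common/status_code.h"
    #include "models/llm/chat_template/llama3_chat_template.h"

    using jinq::models::llm::chat_template::Dialog;
    using jinq::models::llm::chat_template::Llama3ChatTemplate;

    int main() {
        // same dialog shape as the benchmark above
        Dialog dialog = {
            {"system", "You're a smart AI assistant from Mortred Company"},
            {"user", "Who are you?"},
        };
        Llama3ChatTemplate tpl;
        std::string prompt;
        if (tpl.apply_chat_template(dialog, prompt) != jinq::common::StatusCode::OK) {
            return -1;
        }
        // each message renders as
        //   <|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>
        // concatenated in order, then the generation header is appended:
        //   <|start_header_id|>assistant<|end_header_id|>\n\n
        std::cout << prompt << std::endl;
        return 0;
    }

The trailing assistant header is what prompts the model to generate the next
turn, which is why apply_chat_template always appends it after the formatted
dialog.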