diff --git a/src/models/llm/base_generator.h b/src/models/llm/base_generator.h index 7c7c7f3..0b21597 100644 --- a/src/models/llm/base_generator.h +++ b/src/models/llm/base_generator.h @@ -8,4 +8,74 @@ #ifndef MORTRED_MODEL_SERVER_BASE_GENERATOR_H #define MORTRED_MODEL_SERVER_BASE_GENERATOR_H +#include "toml/toml.hpp" + +#include "common/status_code.h" +#include "models/llm/chat_template/base_chat_template.h" + +namespace jinq { +namespace models { +namespace llm { + +#define OUT + +class BaseLlmGenerator { + public: + /*** + * + */ + virtual ~BaseLlmGenerator() = default; + + /*** + * + * @param config + */ + BaseLlmGenerator() = default; + + /*** + * + * @param transformer + */ + BaseLlmGenerator(const BaseLlmGenerator &BaseLlmGenerator) = default; + + /*** + * + * @param transformer + * @return + */ + BaseLlmGenerator &operator=(const BaseLlmGenerator &transformer) = default; + + /*** + * + * @param cfg + * @return + */ + virtual jinq::common::StatusCode init(const decltype(toml::parse("")) &cfg) = 0; + + /*** + * + * @param input + * @param output + * @return + */ + virtual jinq::common::StatusCode text_completion(const std::string& prompt, OUT std::string& generate_output) = 0; + + /*** + * + * @param dialog + * @param generate_output + * @return + */ + virtual jinq::common::StatusCode chat_completion(models::llm::chat_template::Dialog& dialog, OUT std::string& generate_output) = 0; + + /*** + * + * @return + */ + virtual bool is_successfully_initialized() const = 0; +}; +} +} +} + #endif // MORTRED_MODEL_SERVER_BASE_GENERATOR_H diff --git a/src/models/llm/llama/llama3_generator.cpp b/src/models/llm/llama/llama3_generator.cpp index a5a867e..750dff0 100644 --- a/src/models/llm/llama/llama3_generator.cpp +++ b/src/models/llm/llama/llama3_generator.cpp @@ -6,3 +6,203 @@ ************************************************/ #include "llama3_generator.h" + +#include "glog/logging.h" + +#include "common/cv_utils.h" +#include "common/time_stamp.h" +#include "common/file_path_util.h" +#include "models/llm/llama/llama3.h" +#include "models/llm/chat_template/llama3_chat_template.h" + +namespace jinq { +namespace models { +namespace llm { + +using jinq::common::CvUtils; +using jinq::common::Timestamp; +using jinq::common::StatusCode; +using jinq::common::FilePathUtil; +using jinq::models::llm::chat_template::Dialog; +using jinq::models::llm::chat_template::ChatMessage; +using jinq::models::llm::chat_template::Llama3ChatTemplate; +using Llama3Ptr = jinq::models::llm::llama::Llama3; + +namespace llama { + +/***************** Impl Function Sets ******************/ + +class Llama3Generator::Impl { + public: + /*** + * + */ + Impl() = default; + + /*** + * + */ + ~Impl() = default; + + /*** + * + * @param transformer + */ + Impl(const Impl& transformer) = delete; + + /*** + * + * @param transformer + * @return + */ + Impl& operator=(const Impl& transformer) = delete; + + /*** + * + * @param cfg_file_path + * @return + */ + StatusCode init(const decltype(toml::parse("")) &config); + + /*** + * + * @param input + * @param output + * @return + */ + StatusCode text_completion(const std::string& prompt, OUT std::string& generate_output); + + /*** + * + * @param dialog + * @param generate_output + * @return + */ + StatusCode chat_completion(models::llm::chat_template::Dialog& dialog, OUT std::string& generate_output); + + /*** + * + * @return + */ + bool is_successfully_initialized() const { + return _m_successfully_initialized; + }; + + private: + // init flag + bool _m_successfully_initialized = false; + // llm model + Llama3Ptr _m_model; + // chat template + Llama3ChatTemplate _m_chat_template; +}; + +/*** + * + * @param config + * @return + */ +StatusCode Llama3Generator::Impl::init(const decltype(toml::parse("")) &config) { + // init llama3 model + auto status = _m_model.init(config); + if (status != StatusCode::OK) { + _m_successfully_initialized = false; + } else { + _m_successfully_initialized = true; + } + return StatusCode::OK; +} + +/*** + * + * @param prompt + * @param generate_output + * @return + */ +StatusCode Llama3Generator::Impl::text_completion(const std::string &prompt, std::string &generate_output) { + auto status = _m_model.run(prompt, generate_output); + return status; +} + +/*** + * + * @param dialog + * @param generate_output + * @return + */ +StatusCode Llama3Generator::Impl::chat_completion(models::llm::chat_template::Dialog &dialog, std::string &generate_output) { + // template format dialog + std::string fmt_prompt; + auto status = _m_chat_template.apply_chat_template(dialog, fmt_prompt); + if (status != StatusCode::OK) { + LOG(ERROR) << "apply chat template for dialog failed, status code: " << status; + return status; + } + + // chat completion + status = _m_model.run(fmt_prompt, generate_output); + + // log dialog messages + for (auto& msg : dialog) { + LOG(INFO) << fmt::format("{}: {}", msg.role, msg.content); + } + LOG(INFO) << fmt::format("assistant: {}", generate_output); + + return status; +} + +/************* Export Function Sets *************/ + +/*** + * + */ +Llama3Generator::Llama3Generator() { + _m_pimpl = std::make_unique(); +} + +/*** + * + */ +Llama3Generator::~Llama3Generator() = default; + +/*** + * + * @param cfg + * @return + */ +StatusCode Llama3Generator::init(const decltype(toml::parse("")) &cfg) { + return _m_pimpl->init(cfg); +} + +/*** + * + * @return + */ +bool Llama3Generator::is_successfully_initialized() const { + return _m_pimpl->is_successfully_initialized(); +} + +/*** + * + * @param prompt + * @param generate_output + * @return + */ +StatusCode Llama3Generator::text_completion(const std::string &prompt, std::string &generate_output) { + return _m_pimpl-> text_completion(prompt, generate_output); +} + +/*** + * + * @param dialog + * @param generate_output + * @return + */ +StatusCode Llama3Generator::chat_completion(models::llm::chat_template::Dialog &dialog, std::string &generate_output) { + return _m_pimpl->chat_completion(dialog, generate_output); +} + +} +} +} +} \ No newline at end of file diff --git a/src/models/llm/llama/llama3_generator.h b/src/models/llm/llama/llama3_generator.h index 07b19f9..f7f8094 100644 --- a/src/models/llm/llama/llama3_generator.h +++ b/src/models/llm/llama/llama3_generator.h @@ -1,13 +1,90 @@ /************************************************ * Copyright MaybeShewill-CV. All Rights Reserved. * Author: MaybeShewill-CV - * File: llama3_generator.h + * File: Llama3Generator_generator.h * Date: 24-11-28 ************************************************/ #ifndef MORTRED_MODEL_SERVER_LLAMA3_GENERATOR_H #define MORTRED_MODEL_SERVER_LLAMA3_GENERATOR_H -class llama3_generator {}; +#include + +#include "toml/toml.hpp" + +#include "models/base_model.h" +#include "models/model_io_define.h" +#include "models/llm/base_generator.h" +#include "common/status_code.h" + +namespace jinq { +namespace models { +namespace llm { +namespace llama { + +class Llama3Generator : public BaseLlmGenerator { + public: + /*** + * constructor + * @param config + */ + Llama3Generator(); + + /*** + * + */ + ~Llama3Generator() override; + + /*** + * constructor + * @param transformer + */ + Llama3Generator(const Llama3Generator &transformer) = delete; + + /*** + * constructor + * @param transformer + * @return + */ + Llama3Generator &operator=(const Llama3Generator &transformer) = delete; + + /*** + * + * @param toml + * @return + */ + jinq::common::StatusCode init(const decltype(toml::parse("")) &cfg) override; + + /*** + * + * @param input + * @param output + * @return + */ + jinq::common::StatusCode text_completion(const std::string& prompt, OUT std::string& generate_output) override; + + /*** + * + * @param dialog + * @param generate_output + * @return + */ + jinq::common::StatusCode chat_completion(models::llm::chat_template::Dialog& dialog, OUT std::string& generate_output) override; + + /*** + * if model successfully initialized + * @return + */ + bool is_successfully_initialized() const override; + + private: + class Impl; + std::unique_ptr _m_pimpl; +}; + +} +} +} +} #endif // MORTRED_MODEL_SERVER_LLAMA3_GENERATOR_H