
Commit

update
MaybeShewill-CV committed Nov 29, 2024
1 parent 696b2e3 commit 17f2a70
Showing 3 changed files with 349 additions and 2 deletions.
70 changes: 70 additions & 0 deletions src/models/llm/base_generator.h
@@ -8,4 +8,74 @@
#ifndef MORTRED_MODEL_SERVER_BASE_GENERATOR_H
#define MORTRED_MODEL_SERVER_BASE_GENERATOR_H

#include "toml/toml.hpp"

#include "common/status_code.h"
#include "models/llm/chat_template/base_chat_template.h"

namespace jinq {
namespace models {
namespace llm {

// marker macro used to annotate output parameters
#define OUT

class BaseLlmGenerator {
public:
/***
* default destructor
*/
virtual ~BaseLlmGenerator() = default;

/***
* default constructor
*/
BaseLlmGenerator() = default;

/***
* copy constructor
* @param transformer
*/
BaseLlmGenerator(const BaseLlmGenerator &transformer) = default;

/***
* copy assignment operator
* @param transformer
* @return
*/
BaseLlmGenerator &operator=(const BaseLlmGenerator &transformer) = default;

/***
* init generator from a parsed toml config
* @param cfg
* @return
*/
virtual jinq::common::StatusCode init(const decltype(toml::parse("")) &cfg) = 0;

/***
* run text completion on a raw prompt
* @param prompt
* @param generate_output
* @return
*/
virtual jinq::common::StatusCode text_completion(const std::string& prompt, OUT std::string& generate_output) = 0;

/***
* run chat completion on a dialog
* @param dialog
* @param generate_output
* @return
*/
virtual jinq::common::StatusCode chat_completion(models::llm::chat_template::Dialog& dialog, OUT std::string& generate_output) = 0;

/***
* check whether the generator was successfully initialized
* @return
*/
virtual bool is_successfully_initialized() const = 0;
};
} // namespace llm
} // namespace models
} // namespace jinq

#endif // MORTRED_MODEL_SERVER_BASE_GENERATOR_H
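For reference, below is a minimal sketch of a concrete generator built on this interface. The `EchoGenerator` class is hypothetical and not part of this commit; it simply echoes the prompt back so the contract of `init`, `text_completion`, `chat_completion`, and `is_successfully_initialized` is visible in one place.

// echo_generator.h -- hypothetical example, not part of this repository.
// A trivial BaseLlmGenerator implementation that echoes prompts back,
// shown only to illustrate the interface defined above.
#include <string>

#include "models/llm/base_generator.h"

namespace jinq {
namespace models {
namespace llm {

class EchoGenerator : public BaseLlmGenerator {
  public:
    jinq::common::StatusCode init(const decltype(toml::parse("")) &cfg) override {
        (void)cfg;  // a real generator would read model/tokenizer settings here
        _m_initialized = true;
        return jinq::common::StatusCode::OK;
    }

    jinq::common::StatusCode text_completion(const std::string &prompt, OUT std::string &generate_output) override {
        generate_output = prompt;  // placeholder "generation": echo the prompt
        return jinq::common::StatusCode::OK;
    }

    jinq::common::StatusCode chat_completion(models::llm::chat_template::Dialog &dialog, OUT std::string &generate_output) override {
        (void)dialog;  // a real generator would apply a chat template to the dialog first
        generate_output = "(echo) empty reply";
        return jinq::common::StatusCode::OK;
    }

    bool is_successfully_initialized() const override { return _m_initialized; }

  private:
    bool _m_initialized = false;
};

} // namespace llm
} // namespace models
} // namespace jinq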
200 changes: 200 additions & 0 deletions src/models/llm/llama/llama3_generator.cpp
@@ -6,3 +6,203 @@
************************************************/

#include "llama3_generator.h"

#include "glog/logging.h"

#include "common/cv_utils.h"
#include "common/time_stamp.h"
#include "common/file_path_util.h"
#include "models/llm/llama/llama3.h"
#include "models/llm/chat_template/llama3_chat_template.h"

namespace jinq {
namespace models {
namespace llm {

using jinq::common::CvUtils;
using jinq::common::Timestamp;
using jinq::common::StatusCode;
using jinq::common::FilePathUtil;
using jinq::models::llm::chat_template::Dialog;
using jinq::models::llm::chat_template::ChatMessage;
using jinq::models::llm::chat_template::Llama3ChatTemplate;
using Llama3Ptr = jinq::models::llm::llama::Llama3<std::string, std::string>;

namespace llama {

/***************** Impl Function Sets ******************/

class Llama3Generator::Impl {
public:
/***
* default constructor
*/
Impl() = default;

/***
* default destructor
*/
~Impl() = default;

/***
*
* @param transformer
*/
Impl(const Impl& transformer) = delete;

/***
*
* @param transformer
* @return
*/
Impl& operator=(const Impl& transformer) = delete;

/***
* init llama3 generator from a parsed toml config
* @param config
* @return
*/
StatusCode init(const decltype(toml::parse("")) &config);

/***
* run text completion on a raw prompt
* @param prompt
* @param generate_output
* @return
*/
StatusCode text_completion(const std::string& prompt, OUT std::string& generate_output);

/***
*
* @param dialog
* @param generate_output
* @return
*/
StatusCode chat_completion(models::llm::chat_template::Dialog& dialog, OUT std::string& generate_output);

/***
*
* @return
*/
bool is_successfully_initialized() const {
return _m_successfully_initialized;
}

private:
// init flag
bool _m_successfully_initialized = false;
// llm model
Llama3Ptr _m_model;
// chat template
Llama3ChatTemplate _m_chat_template;
};

/***
*
* @param config
* @return
*/
StatusCode Llama3Generator::Impl::init(const decltype(toml::parse("")) &config) {
// init llama3 model
auto status = _m_model.init(config);
if (status != StatusCode::OK) {
LOG(ERROR) << "init llama3 model failed, status code: " << status;
_m_successfully_initialized = false;
} else {
_m_successfully_initialized = true;
}
return status;
}

/***
*
* @param prompt
* @param generate_output
* @return
*/
StatusCode Llama3Generator::Impl::text_completion(const std::string &prompt, std::string &generate_output) {
auto status = _m_model.run(prompt, generate_output);
return status;
}

/***
*
* @param dialog
* @param generate_output
* @return
*/
StatusCode Llama3Generator::Impl::chat_completion(models::llm::chat_template::Dialog &dialog, std::string &generate_output) {
// template format dialog
std::string fmt_prompt;
auto status = _m_chat_template.apply_chat_template(dialog, fmt_prompt);
if (status != StatusCode::OK) {
LOG(ERROR) << "apply chat template for dialog failed, status code: " << status;
return status;
}

// chat completion
status = _m_model.run(fmt_prompt, generate_output);

// log dialog messages
for (auto& msg : dialog) {
LOG(INFO) << fmt::format("{}: {}", msg.role, msg.content);
}
LOG(INFO) << fmt::format("assistant: {}", generate_output);

return status;
}

/************* Export Function Sets *************/

/***
*
*/
Llama3Generator::Llama3Generator() {
_m_pimpl = std::make_unique<Impl>();
}

/***
*
*/
Llama3Generator::~Llama3Generator() = default;

/***
*
* @param cfg
* @return
*/
StatusCode Llama3Generator::init(const decltype(toml::parse("")) &cfg) {
return _m_pimpl->init(cfg);
}

/***
*
* @return
*/
bool Llama3Generator::is_successfully_initialized() const {
return _m_pimpl->is_successfully_initialized();
}

/***
*
* @param prompt
* @param generate_output
* @return
*/
StatusCode Llama3Generator::text_completion(const std::string &prompt, std::string &generate_output) {
return _m_pimpl->text_completion(prompt, generate_output);
}

/***
*
* @param dialog
* @param generate_output
* @return
*/
StatusCode Llama3Generator::chat_completion(models::llm::chat_template::Dialog &dialog, std::string &generate_output) {
return _m_pimpl->chat_completion(dialog, generate_output);
}

} // namespace llama
} // namespace llm
} // namespace models
} // namespace jinq
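A short usage sketch for the new generator (not part of this commit): parse a toml config, initialize the generator, and run a plain text completion. The config path is hypothetical; chat_completion would additionally need a populated chat_template::Dialog, whose construction helpers are not shown in this diff.

// usage sketch -- hypothetical driver, not part of this repository.
#include <string>

#include "glog/logging.h"
#include "toml/toml.hpp"

#include "common/status_code.h"
#include "models/llm/llama/llama3_generator.h"

int main() {
    using jinq::common::StatusCode;
    jinq::models::llm::llama::Llama3Generator generator;

    // parse model configuration and initialize the generator
    auto cfg = toml::parse("./conf/llama3_generator.toml");  // hypothetical config path
    if (generator.init(cfg) != StatusCode::OK || !generator.is_successfully_initialized()) {
        LOG(ERROR) << "init llama3 generator failed";
        return -1;
    }

    // plain text completion
    std::string output;
    auto status = generator.text_completion("Briefly explain what a KV cache is.", output);
    if (status != StatusCode::OK) {
        LOG(ERROR) << "text completion failed, status code: " << status;
        return -1;
    }
    LOG(INFO) << "generation: " << output;
    return 0;
}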
81 changes: 79 additions & 2 deletions src/models/llm/llama/llama3_generator.h
@@ -1,13 +1,90 @@
/************************************************
* Copyright MaybeShewill-CV. All Rights Reserved.
* Author: MaybeShewill-CV
* File: llama3_generator.h
* Date: 24-11-28
************************************************/

#ifndef MORTRED_MODEL_SERVER_LLAMA3_GENERATOR_H
#define MORTRED_MODEL_SERVER_LLAMA3_GENERATOR_H

#include <memory>

#include "toml/toml.hpp"

#include "models/base_model.h"
#include "models/model_io_define.h"
#include "models/llm/base_generator.h"
#include "common/status_code.h"

namespace jinq {
namespace models {
namespace llm {
namespace llama {

class Llama3Generator : public BaseLlmGenerator {
public:
/***
* default constructor
*/
Llama3Generator();

/***
* destructor
*/
~Llama3Generator() override;

/***
* copy constructor (deleted)
* @param transformer
*/
Llama3Generator(const Llama3Generator &transformer) = delete;

/***
* copy assignment operator (deleted)
* @param transformer
* @return
*/
Llama3Generator &operator=(const Llama3Generator &transformer) = delete;

/***
* init generator from a parsed toml config
* @param cfg
* @return
*/
jinq::common::StatusCode init(const decltype(toml::parse("")) &cfg) override;

/***
* run text completion on a raw prompt
* @param prompt
* @param generate_output
* @return
*/
jinq::common::StatusCode text_completion(const std::string& prompt, OUT std::string& generate_output) override;

/***
* run chat completion on a dialog
* @param dialog
* @param generate_output
* @return
*/
jinq::common::StatusCode chat_completion(models::llm::chat_template::Dialog& dialog, OUT std::string& generate_output) override;

/***
* check whether the model was successfully initialized
* @return
*/
bool is_successfully_initialized() const override;

private:
class Impl;
std::unique_ptr<Impl> _m_pimpl;
};

} // namespace llama
} // namespace llm
} // namespace models
} // namespace jinq

#endif // MORTRED_MODEL_SERVER_LLAMA3_GENERATOR_H
