
Commit

update
MaybeShewill-CV committed Dec 4, 2024
1 parent 1894a28 commit 0e68a1a
Showing 2 changed files with 57 additions and 12 deletions.
11 changes: 11 additions & 0 deletions src/models/llm/llama/llama3.h
@@ -22,6 +22,11 @@ namespace models {
namespace llm {
namespace llama {

struct ModelStatus {
uint32_t n_ctx_size;
int32_t kv_cache_cell_nums;
};

template <typename INPUT, typename OUTPUT>
class Llama3 : public jinq::models::BaseAiModel<INPUT, OUTPUT> {
public:
@@ -72,6 +77,12 @@ class Llama3 : public jinq::models::BaseAiModel<INPUT, OUTPUT> {
*/
jinq::common::StatusCode tokenize_prompt(const std::string &prompt, std::vector<llama_token> &prompt_tokens);

/***
* get the model runtime status: context window size and number of kv-cache cells currently in use
* @return ModelStatus
*/
ModelStatus get_model_stat() const;

/***
* if model successfully initialized
* @return
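
Note on the new API: ModelStatus and get_model_stat() expose the context window size and the number of kv-cache cells already in use, so callers can estimate how much room is left before generation would fail with "context size exceeded". The snippet below is a minimal usage sketch, not part of this commit; it assumes an already constructed and initialized Llama3<std::string, std::string> instance, that tokenize_prompt() and get_model_stat() are part of the public interface, and the include path and helper name prompt_fits are illustrative only.

// Hypothetical usage sketch (not part of this commit).
#include <string>
#include <vector>
#include "models/llm/llama/llama3.h"   // illustrative include path

using jinq::common::StatusCode;
using jinq::models::llm::llama::Llama3;
using jinq::models::llm::llama::ModelStatus;

// returns true when `prompt` still fits into the remaining context window
bool prompt_fits(Llama3<std::string, std::string> &model, const std::string &prompt) {
    std::vector<llama_token> tokens;
    if (model.tokenize_prompt(prompt, tokens) != StatusCode::OK) {
        return false;
    }
    ModelStatus stat = model.get_model_stat();
    // free cells = total context size minus cells already occupied by the kv cache
    auto remaining = static_cast<int64_t>(stat.n_ctx_size) - stat.kv_cache_cell_nums;
    return static_cast<int64_t>(tokens.size()) <= remaining;
}
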
58 changes: 46 additions & 12 deletions src/models/llm/llama/llama3.inl
@@ -51,8 +51,7 @@ typename std::enable_if<std::is_same<INPUT, char*>::value, internal_input>::type
* @return
*/
template <typename INPUT>
typename std::enable_if<std::is_same<INPUT, std::string>::value, internal_input>::type transform_input(
const INPUT& in) {
typename std::enable_if<std::is_same<INPUT, std::string>::value, internal_input>::type transform_input(const INPUT& in) {
return in;
}

@@ -126,6 +125,12 @@ public:
*/
StatusCode tokenize_prompt(const std::string& prompt, std::vector<llama_token>& prompt_tokens);

/***
* get the model runtime status (context size and used kv-cache cells)
* @return ModelStatus
*/
ModelStatus get_model_stat() const;

/***
*
* @return
@@ -287,15 +292,13 @@ StatusCode Llama3<INPUT, OUTPUT>::Impl::run(const INPUT& in, OUTPUT& out) {
* @return
*/
template <typename INPUT, typename OUTPUT>
StatusCode Llama3<INPUT, OUTPUT>::Impl::tokenize_prompt(const std::string& prompt,
std::vector<llama_token>& prompt_tokens) {
StatusCode Llama3<INPUT, OUTPUT>::Impl::tokenize_prompt(const std::string& prompt, std::vector<llama_token>& prompt_tokens) {
if (prompt.empty()) {
LOG(WARNING) << "input prompt is empty";
return StatusCode::TOKENIZE_FAILED;
}

auto n_prompt_tokens = llama_tokenize(_m_model, prompt.c_str(), static_cast<int32_t>(prompt.size()), nullptr, 0, true,
true);
auto n_prompt_tokens = llama_tokenize(_m_model, prompt.c_str(), static_cast<int32_t>(prompt.size()), nullptr, 0, true, true);
n_prompt_tokens *= -1;
prompt_tokens.resize(n_prompt_tokens);
auto prompt_size = static_cast<int32_t >(prompt.size());
@@ -311,6 +314,20 @@ StatusCode Llama3<INPUT, OUTPUT>::Impl::tokenize_prompt(const std::string& promp
return StatusCode::OK;
}

/***
* get the model runtime status: context window size and number of used kv-cache cells
* @tparam INPUT
* @tparam OUTPUT
* @return ModelStatus
*/
template <typename INPUT, typename OUTPUT>
ModelStatus Llama3<INPUT, OUTPUT>::Impl::get_model_stat() const {
ModelStatus stat{};
stat.n_ctx_size = llama_n_ctx(_m_ctx);
stat.kv_cache_cell_nums = llama_get_kv_cache_used_cells(_m_ctx);
return stat;
}

/***
*
* @tparam INPUT
@@ -319,22 +336,28 @@ StatusCode Llama3<INPUT, OUTPUT>::Impl::tokenize_prompt(const std::string& promp
* @return
*/
template <typename INPUT, typename OUTPUT>
StatusCode Llama3<INPUT, OUTPUT>::Impl::llama_generate(std::vector<llama_token>& prompt_tokens,
std::string& generate_out) {
StatusCode Llama3<INPUT, OUTPUT>::Impl::llama_generate(std::vector<llama_token>& prompt_tokens, std::string& generate_out) {
// prepare a batch for the prompt
LOG(INFO) << "prompt token size: " << prompt_tokens.size();
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
llama_token new_token_id;
while (true) {
// check if we have enough space in the context to evaluate this batch
int n_ctx = llama_n_ctx(_m_ctx);
int n_ctx_used = llama_get_kv_cache_used_cells(_m_ctx);
LOG(INFO) << "context size: " << n_ctx;
LOG(INFO) << "kv cache size: " << n_ctx_used;
if (n_ctx_used + batch.n_tokens > n_ctx) {
LOG(ERROR) << "context size exceeded";
return StatusCode::MODEL_RUN_SESSION_FAILED;
}

if (llama_decode(_m_ctx, batch)) {
LOG(ERROR) << "llama generate failed";
auto status = llama_decode(_m_ctx, batch);
if (status == 1) {
LOG(WARNING) << "llama generate failed. could not find a KV slot for the batch "
"(try reducing the size of the batch or increase the context)";
} else if (status < 0) {
LOG(ERROR) << "llama decode failed code: " << status;
return StatusCode::MODEL_RUN_SESSION_FAILED;
}

@@ -354,8 +377,8 @@ StatusCode Llama3<INPUT, OUTPUT>::Impl::llama_generate(std::vector<llama_token>&
return StatusCode::MODEL_RUN_SESSION_FAILED;
}
std::string piece(buf, n);
// printf("%s", piece.c_str());
// fflush(stdout);
generate_out += piece;

// prepare the next batch with the sampled token
@@ -434,6 +457,17 @@ StatusCode Llama3<INPUT, OUTPUT>::tokenize_prompt(const std::string& prompt, std
return _m_pimpl->tokenize_prompt(prompt, prompt_tokens);
}

/***
* get the model runtime status, forwarded to the pimpl implementation
* @tparam INPUT
* @tparam OUTPUT
* @return ModelStatus
*/
template <typename INPUT, typename OUTPUT>
ModelStatus Llama3<INPUT, OUTPUT>::get_model_stat() const {
return _m_pimpl->get_model_stat();
}

}
}
}
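
For reference, get_model_stat() is backed by the same two llama.cpp calls that llama_generate now uses for its pre-flight capacity check, llama_n_ctx() and llama_get_kv_cache_used_cells(). llama_decode() itself is documented by llama.cpp to return 0 on success, 1 when no kv-cache slot could be found for the batch, and a negative value on other errors, which is what the new status handling distinguishes. The helper below is a hypothetical restatement of that capacity check against the raw llama.cpp API; the function name batch_fits_context is illustrative and not part of the repository.

// Hypothetical helper (not part of this commit): the same capacity check
// that llama_generate performs before every llama_decode() call.
#include <llama.h>   // llama.cpp public header

// true when n_new_tokens more tokens still fit into the context window,
// i.e. used kv-cache cells + new tokens <= total context size
static bool batch_fits_context(llama_context *ctx, int32_t n_new_tokens) {
    const auto n_ctx      = static_cast<int32_t>(llama_n_ctx(ctx));   // ModelStatus::n_ctx_size
    const auto n_ctx_used = llama_get_kv_cache_used_cells(ctx);       // ModelStatus::kv_cache_cell_nums
    return n_ctx_used + n_new_tokens <= n_ctx;
}
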

