
Commit

update
MaybeShewill-CV committed Dec 4, 2024
1 parent ad1e25d commit b9ba45d
Showing 5 changed files with 67 additions and 6 deletions.
18 changes: 17 additions & 1 deletion src/apps/model_benchmark/llm/llama3_benchmark.cpp
@@ -48,10 +48,26 @@ int main(int argc, char** argv) {
{"system", "You're a smart AI assistant from Mortred Company"},
{"user", "Who are you?"},
{"assistant", "I am a ai assistant"},
{"user", "Where are you from?"},
{"user", "Who is your favorite singer?"},
};
std::string gen_out;
generator.chat_completion(dialog, gen_out);
dialog.messages.emplace_back("assistant", gen_out);
LOG(INFO) << "assistant: " << gen_out;

Dialog new_dialog;
new_dialog.messages.emplace_back("user", "answer last question again");
generator.chat_completion(new_dialog, gen_out);
dialog.messages.emplace_back("assistant", gen_out);
LOG(INFO) << "assistant: " << gen_out;

generator.clear_kv_cache_cell();
generator.chat_completion(new_dialog, gen_out);
LOG(INFO) << "assistant: " << gen_out;

generator.clear_kv_cache_cell();
auto status = generator.chat_completion(dialog, gen_out);
LOG(INFO) << "assistant: " << gen_out;

return status;
}
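
The new benchmark steps exercise the added clear_kv_cache_cell() call: the follow-up prompt "answer last question again" is generated once while the kv cache still holds the earlier turns, then again after the cache has been cleared, and finally the original dialog is replayed from an empty cache, so the three answers can be compared. A minimal usage sketch of the same pattern follows; the construction and initialization of the Llama3Generator sit outside the shown hunk, so the generator object below is assumed to be ready.

// sketch only: assumes an already-initialized Llama3Generator named "generator"
Dialog dialog;
dialog.messages.emplace_back("system", "You're a smart AI assistant from Mortred Company");
dialog.messages.emplace_back("user", "Who are you?");
std::string gen_out;
generator.chat_completion(dialog, gen_out);        // kv cache now holds this dialog's tokens
dialog.messages.emplace_back("assistant", gen_out);

generator.clear_kv_cache_cell();                   // drop the cached turns before an unrelated dialog
Dialog other;
other.messages.emplace_back("user", "Who is your favorite singer?");
auto status = generator.chat_completion(other, gen_out);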
5 changes: 5 additions & 0 deletions src/models/llm/llama/llama3.h
@@ -83,6 +83,11 @@ class Llama3 : public jinq::models::BaseAiModel<INPUT, OUTPUT> {
*/
ModelStatus get_model_stat() const;

/***
 * clear all cells of the llama context's kv cache
 */
void clear_kv_cache_cell() const;

/***
* if model successfully initialized
* @return
31 changes: 26 additions & 5 deletions src/models/llm/llama/llama3.inl
@@ -131,6 +131,11 @@ public:
*/
ModelStatus get_model_stat() const;

/***
 * clear all cells of the llama context's kv cache
 */
void clear_kv_cache_cell() const;

/***
*
* @return
@@ -328,6 +333,17 @@ ModelStatus Llama3<INPUT, OUTPUT>::Impl::get_model_stat() const {
return stat;
}

/***
 * clear all cells of the llama context's kv cache
 * @tparam INPUT
 * @tparam OUTPUT
 */
template <typename INPUT, typename OUTPUT>
void Llama3<INPUT, OUTPUT>::Impl::clear_kv_cache_cell() const {
llama_kv_cache_clear(_m_ctx);
}

/***
*
* @tparam INPUT
@@ -338,15 +354,12 @@ ModelStatus Llama3<INPUT, OUTPUT>::Impl::get_model_stat() const {
template <typename INPUT, typename OUTPUT>
StatusCode Llama3<INPUT, OUTPUT>::Impl::llama_generate(std::vector<llama_token>& prompt_tokens, std::string& generate_out) {
// prepare a batch for the prompt
LOG(INFO) << "prompt token size: " << prompt_tokens.size();
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
llama_token new_token_id;
while (true) {
// check if we have enough space in the context to evaluate this batch
int n_ctx = llama_n_ctx(_m_ctx);
int n_ctx_used = llama_get_kv_cache_used_cells(_m_ctx);
LOG(INFO) << "context size: " << n_ctx;
LOG(INFO) << "kv cache size: " << n_ctx_used;
if (n_ctx_used + batch.n_tokens > n_ctx) {
LOG(ERROR) << "context size exceeded";
return StatusCode::MODEL_RUN_SESSION_FAILED;
@@ -377,8 +390,6 @@ StatusCode Llama3<INPUT, OUTPUT>::Impl::llama_generate(std::vector<llama_token>&
return StatusCode::MODEL_RUN_SESSION_FAILED;
}
std::string piece(buf, n);
// printf("%s", piece.c_str());
// fflush(stdout);
generate_out += piece;

// prepare the next batch with the sampled token
@@ -468,6 +479,16 @@ ModelStatus Llama3<INPUT, OUTPUT>::get_model_stat() const {
return _m_pimpl->get_model_stat();
}

/***
 * clear all cells of the llama context's kv cache
 * @tparam INPUT
 * @tparam OUTPUT
 */
template <typename INPUT, typename OUTPUT>
void Llama3<INPUT, OUTPUT>::clear_kv_cache_cell() const {
return _m_pimpl->clear_kv_cache_cell();
}

}
}
}
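llama_kv_cache_clear() is the llama.cpp C API call that erases every cell of the context's kv cache, so the next decode starts with no cached tokens. The generation loop shown above budgets each step against that cache; a small illustrative sketch of the guard it relies on (identifiers taken from the diff, not the committed code verbatim):

// illustrative only: the loop above performs an equivalent check before each decode step
int n_ctx = llama_n_ctx(_m_ctx);                        // total context window
int n_ctx_used = llama_get_kv_cache_used_cells(_m_ctx); // cells already occupied by earlier calls
if (n_ctx_used + batch.n_tokens > n_ctx) {
    // without clear_kv_cache_cell() between dialogs, the cache keeps growing
    // across chat_completion calls until this branch fails the session
    return StatusCode::MODEL_RUN_SESSION_FAILED;
}

Clearing the whole cache is the simplest reset; trimming individual sequences would need a different llama.cpp call and is not part of this commit.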
14 changes: 14 additions & 0 deletions src/models/llm/llama/llama3_generator.cpp
@@ -81,6 +81,13 @@ class Llama3Generator::Impl {
*/
StatusCode chat_completion(models::llm::chat_template::Dialog& dialog, OUT std::string& generate_output);

/***
 * clear the underlying model's kv cache
 */
void clear_kv_cache_cell() {
return _m_model.clear_kv_cache_cell();
}

/***
*
* @return
@@ -227,6 +234,13 @@ StatusCode Llama3Generator::chat_completion(models::llm::chat_template::Dialog &
return _m_pimpl->chat_completion(dialog, generate_output);
}

/***
 * clear the underlying model's kv cache
 */
void Llama3Generator::clear_kv_cache_cell() {
return _m_pimpl->clear_kv_cache_cell();
}

}
}
}
5 changes: 5 additions & 0 deletions src/models/llm/llama/llama3_generator.h
@@ -72,6 +72,11 @@ class Llama3Generator : public BaseLlmGenerator {
*/
jinq::common::StatusCode chat_completion(models::llm::chat_template::Dialog& dialog, OUT std::string& generate_output) override;

/***
 * clear all cached kv cells so the next chat_completion starts from an empty context
 */
void clear_kv_cache_cell();

/***
* if model successfully initialized
* @return
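Taken together, the five files wire up one new call path: Llama3Generator::clear_kv_cache_cell() forwards to its Impl, which calls Llama3<INPUT, OUTPUT>::clear_kv_cache_cell(), whose pimpl finally invokes llama_kv_cache_clear(_m_ctx). The benchmark calls it between dialogs to show how the answers change once the cached turns are discarded.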
