diff --git a/src/models/llm/llama/llama3.h b/src/models/llm/llama/llama3.h
index 644b9dd..f595cea 100644
--- a/src/models/llm/llama/llama3.h
+++ b/src/models/llm/llama/llama3.h
@@ -88,6 +88,13 @@ class Llama3 : public jinq::models::BaseAiModel<INPUT, OUTPUT> {
      */
     void clear_kv_cache_cell() const;
 
+    /***
+     * count the tokens the prompt encodes to, without running inference
+     * @param prompt input text to tokenize
+     * @return token count (0 for an empty prompt)
+     */
+    int32_t count_prompt_token_nums(const std::string& prompt) const;
+
     /***
      * if model successfully initialized
      * @return
@@ -98,6 +105,7 @@ class Llama3 : public jinq::models::BaseAiModel<INPUT, OUTPUT> {
     class Impl;
     std::unique_ptr<Impl> _m_pimpl;
 };
+
 }
 }
 }
diff --git a/src/models/llm/llama/llama3.inl b/src/models/llm/llama/llama3.inl
index 2023736..7d0af47 100644
--- a/src/models/llm/llama/llama3.inl
+++ b/src/models/llm/llama/llama3.inl
@@ -136,6 +136,8 @@ public:
      */
     void clear_kv_cache_cell() const;
 
+    int32_t count_prompt_token_nums(const std::string& prompt) const;
+
     /***
      *
      * @return
@@ -344,6 +346,22 @@ void Llama3<INPUT, OUTPUT>::Impl::clear_kv_cache_cell() const {
     llama_kv_cache_clear(_m_ctx);
 }
 
+/***
+ * size-probe llama_tokenize with a null buffer: it returns the negated required token count, negated back here
+ * @tparam INPUT
+ * @tparam OUTPUT
+ * @param prompt input text to tokenize
+ * @return token count (0 for an empty prompt)
+ */
+template <typename INPUT, typename OUTPUT>
+int32_t Llama3<INPUT, OUTPUT>::Impl::count_prompt_token_nums(const std::string &prompt) const {
+    if (prompt.empty()) {
+        return 0;
+    }
+    auto n_prompt_tokens = llama_tokenize(_m_model, prompt.c_str(), static_cast<int32_t>(prompt.size()), nullptr, 0, true, true);
+    return -n_prompt_tokens;
+}
+
 /***
  *
  * @tparam INPUT
@@ -362,7 +380,7 @@ StatusCode Llama3<INPUT, OUTPUT>::Impl::llama_generate(std::vector<llama_token>&
         int n_ctx_used = llama_get_kv_cache_used_cells(_m_ctx);
         if (n_ctx_used + batch.n_tokens > n_ctx) {
             LOG(ERROR) << "context size exceeded";
-            return StatusCode::MODEL_RUN_SESSION_FAILED;
+            return StatusCode::LLM_CONTEXT_SIZE_EXCEEDED;
         }
 
         auto status = llama_decode(_m_ctx, batch);
@@ -489,6 +507,18 @@ void Llama3<INPUT, OUTPUT>::clear_kv_cache_cell() const {
     return _m_pimpl->clear_kv_cache_cell();
 }
 
+/***
+ * count the tokens the prompt encodes to; forwards to the pimpl
+ * @tparam INPUT
+ * @tparam OUTPUT
+ * @param prompt input text to tokenize
+ * @return token count (0 for an empty prompt)
+ */
+template <typename INPUT, typename OUTPUT>
+int32_t Llama3<INPUT, OUTPUT>::count_prompt_token_nums(const std::string &prompt) const {
+    return _m_pimpl->count_prompt_token_nums(prompt);
+}
+
 }
 }
 }