diff --git a/src/models/llm/llama/llama3.h b/src/models/llm/llama/llama3.h
index 644b9dd..f595cea 100644
--- a/src/models/llm/llama/llama3.h
+++ b/src/models/llm/llama/llama3.h
@@ -88,6 +88,13 @@ class Llama3 : public jinq::models::BaseAiModel<INPUT, OUTPUT> {
      */
     void clear_kv_cache_cell() const;
 
+    /***
+     * count how many tokens the model's tokenizer needs to encode a prompt
+     * @param prompt input prompt text
+     * @return token count, 0 for an empty prompt
+     */
+    int32_t count_prompt_token_nums(const std::string& prompt) const;
+
     /***
      * if model successfully initialized
      * @return
@@ -98,6 +105,7 @@ class Llama3 : public jinq::models::BaseAiModel<INPUT, OUTPUT> {
     class Impl;
     std::unique_ptr<Impl> _m_pimpl;
 };
+
 }
 }
 }
diff --git a/src/models/llm/llama/llama3.inl b/src/models/llm/llama/llama3.inl
index 2023736..7d0af47 100644
--- a/src/models/llm/llama/llama3.inl
+++ b/src/models/llm/llama/llama3.inl
@@ -136,6 +136,8 @@ public:
      */
     void clear_kv_cache_cell() const;
 
+    int32_t count_prompt_token_nums(const std::string& prompt) const;
+
    /***
     *
     * @return
@@ -344,6 +346,22 @@ void Llama3<INPUT, OUTPUT>::Impl::clear_kv_cache_cell() const {
     llama_kv_cache_clear(_m_ctx);
 }
 
+/***
+ * count how many tokens the tokenizer needs to encode a prompt
+ * @tparam INPUT
+ * @tparam OUTPUT
+ * @param prompt input prompt text
+ * @return token count (llama_tokenize called with a null buffer returns the negated required count, hence the sign flip), 0 for an empty prompt
+ */
+template <typename INPUT, typename OUTPUT>
+int32_t Llama3<INPUT, OUTPUT>::Impl::count_prompt_token_nums(const std::string &prompt) const {
+    if (prompt.empty()) {
+        return 0;
+    }
+    auto n_prompt_tokens = llama_tokenize(_m_model, prompt.c_str(), static_cast<int32_t>(prompt.size()), nullptr, 0, true, true);
+    return -n_prompt_tokens;
+}
+
 /***
  *
  * @tparam INPUT
@@ -362,7 +380,7 @@ StatusCode Llama3<INPUT, OUTPUT>::Impl::llama_generate(std::vector<llama_token>&
     int n_ctx_used = llama_get_kv_cache_used_cells(_m_ctx);
     if (n_ctx_used + batch.n_tokens > n_ctx) {
         LOG(ERROR) << "context size exceeded";
-        return StatusCode::MODEL_RUN_SESSION_FAILED;
+        return StatusCode::LLM_CONTEXT_SIZE_EXCEEDED;
     }
 
     auto status = llama_decode(_m_ctx, batch);
@@ -489,6 +507,18 @@ void Llama3<INPUT, OUTPUT>::clear_kv_cache_cell() const {
     return _m_pimpl->clear_kv_cache_cell();
 }
 
+/***
+ * count how many tokens the tokenizer needs to encode a prompt
+ * @tparam INPUT
+ * @tparam OUTPUT
+ * @param prompt input prompt text
+ * @return token count, 0 for an empty prompt
+ */
+template <typename INPUT, typename OUTPUT>
+int32_t Llama3<INPUT, OUTPUT>::count_prompt_token_nums(const std::string &prompt) const {
+    return _m_pimpl->count_prompt_token_nums(prompt);
+}
+
 }
 }
 }