Commit 16c5da3: update

MaybeShewill-CV committed Dec 5, 2024
1 parent b9ba45d

Showing 2 changed files with 39 additions and 1 deletion.
8 changes: 8 additions & 0 deletions src/models/llm/llama/llama3.h
@@ -88,6 +88,13 @@ class Llama3 : public jinq::models::BaseAiModel<INPUT, OUTPUT> {
*/
void clear_kv_cache_cell() const;

/***
 * count the number of tokens the given prompt tokenizes into
 * @param prompt input prompt text
 * @return token count of the prompt, 0 if the prompt is empty
 */
int32_t count_prompt_token_nums(const std::string& prompt) const;

/***
* if model successfully initialized
* @return
@@ -98,6 +105,7 @@ class Llama3 : public jinq::models::BaseAiModel<INPUT, OUTPUT> {
class Impl;
std::unique_ptr<Impl> _m_pimpl;
};

}
}
}
32 changes: 31 additions & 1 deletion src/models/llm/llama/llama3.inl
@@ -136,6 +136,8 @@ public:
*/
void clear_kv_cache_cell() const;

int32_t count_prompt_token_nums(const std::string& prompt) const;

/***
*
* @return
@@ -344,6 +346,22 @@ void Llama3<INPUT, OUTPUT>::Impl::clear_kv_cache_cell() const {
llama_kv_cache_clear(_m_ctx);
}

/***
 * count the number of tokens the given prompt tokenizes into
 * @tparam INPUT
 * @tparam OUTPUT
 * @param prompt input prompt text
 * @return token count of the prompt, 0 if the prompt is empty
 */
template <typename INPUT, typename OUTPUT>
int32_t Llama3<INPUT, OUTPUT>::Impl::count_prompt_token_nums(const std::string &prompt) const {
    if (prompt.empty()) {
        return 0;
    }
    // a null output buffer makes llama_tokenize return the negative of the
    // required token count, so negating it yields the prompt's token number
    auto n_prompt_tokens = llama_tokenize(_m_model, prompt.c_str(), static_cast<int32_t>(prompt.size()), nullptr, 0, true, true);
    return -n_prompt_tokens;
}
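Note (not part of this commit): the single call above only measures the prompt; fetching the actual tokens takes the usual two-pass llama_tokenize pattern. A minimal sketch, assuming llama.cpp's convention that a null or too-small output buffer makes the call return the negative of the required count; the helper name is illustrative:

#include <string>
#include <vector>
#include "llama.h"

// illustrative helper, not in this commit: tokenize a prompt in two passes
std::vector<llama_token> tokenize_prompt(const llama_model* model, const std::string& prompt) {
    // pass 1: null buffer -> llama_tokenize returns -(required token count)
    const int32_t n_needed = -llama_tokenize(model, prompt.c_str(), static_cast<int32_t>(prompt.size()),
                                             nullptr, 0, true, true);
    std::vector<llama_token> tokens(n_needed > 0 ? static_cast<size_t>(n_needed) : 0);
    // pass 2: fill the correctly sized buffer
    llama_tokenize(model, prompt.c_str(), static_cast<int32_t>(prompt.size()),
                   tokens.data(), static_cast<int32_t>(tokens.size()), true, true);
    return tokens;
}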

/***
*
* @tparam INPUT
Expand All @@ -362,7 +380,7 @@ StatusCode Llama3<INPUT, OUTPUT>::Impl::llama_generate(std::vector<llama_token>&
int n_ctx_used = llama_get_kv_cache_used_cells(_m_ctx);
if (n_ctx_used + batch.n_tokens > n_ctx) {
LOG(ERROR) << "context size exceeded";
-        return StatusCode::MODEL_RUN_SESSION_FAILED;
+        return StatusCode::LLM_CONTEXT_SIZE_EXCEEDED;
}

auto status = llama_decode(_m_ctx, batch);
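The hunk above also changes the failure path: overflowing the context window now reports the specific LLM_CONTEXT_SIZE_EXCEEDED code instead of the generic MODEL_RUN_SESSION_FAILED. A hedged caller-side sketch of what that distinction enables; the call shape and the recovery policy here are assumptions, not code from this repository:

// sketch only: react to the more specific status code (call shape assumed)
auto status = model.llama_generate(prompt_tokens, generate_output);
if (status == StatusCode::LLM_CONTEXT_SIZE_EXCEEDED) {
    // the context is full rather than the session being broken, so clearing
    // the KV cache (or truncating the prompt) and retrying may recover
    model.clear_kv_cache_cell();
}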
@@ -489,6 +507,18 @@ void Llama3<INPUT, OUTPUT>::clear_kv_cache_cell() const {
return _m_pimpl->clear_kv_cache_cell();
}

/***
 * count the number of tokens the given prompt tokenizes into
 * @tparam INPUT
 * @tparam OUTPUT
 * @param prompt input prompt text
 * @return token count of the prompt
 */
template <typename INPUT, typename OUTPUT>
int32_t Llama3<INPUT, OUTPUT>::count_prompt_token_nums(const std::string &prompt) const {
return _m_pimpl->count_prompt_token_nums(prompt);
}
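A short usage sketch for the new public method, tying it to the context-size guard above: measure the prompt before decoding and trim or reject oversized inputs up front. The template arguments, initialization, and context limit are assumptions for illustration:

// hypothetical caller, not from the diff
Llama3<std::string, std::string> model;
// ... model/config initialization elided ...
const std::string prompt = "Summarize the KV-cache design in two sentences.";
const int32_t n_prompt_tokens = model.count_prompt_token_nums(prompt);
const int32_t max_ctx_tokens = 4096; // assumed context window
if (n_prompt_tokens > max_ctx_tokens) {
    // trim or reject up front instead of hitting LLM_CONTEXT_SIZE_EXCEEDED mid-run
}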

}
}
}
