[LLAMA_CPP] Use separate states for infer requests (#908)
* Fix previous test

* Add test

* Implement separate states
vshampor authored May 7, 2024
1 parent a6b9f14 commit bb43dc1
Showing 6 changed files with 71 additions and 23 deletions.
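
For API consumers, the effect of this commit is that every ov::InferRequest created from a LLAMA_CPP compiled model now owns its own llama_context (and therefore its own KV cache), so resetting the state of one request no longer clears the others. A minimal sketch of that usage, assuming the plugin is exposed under the "LLAMA_CPP" device name and that a GGUF file path can be passed directly to compile_model (the file name below is hypothetical):

#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    // Assumption: the llama_cpp_plugin accepts a GGUF path directly and is registered as the "LLAMA_CPP" device.
    ov::CompiledModel model = core.compile_model("gpt2.gguf", "LLAMA_CPP");

    // Each infer request now creates its own llama_context in its constructor.
    ov::InferRequest first = model.create_infer_request();
    ov::InferRequest second = model.create_infer_request();

    // ... run first.infer() and second.infer() on separate prompts ...

    // Clears only the KV cache behind `first`; `second` keeps its accumulated state.
    first.reset_state();
    return 0;
}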
1 change: 1 addition & 0 deletions modules/llama_cpp_plugin/include/compiled_model.hpp
@@ -61,6 +61,7 @@ class LlamaCppModel : public ICompiledModel {
 private:
     gguf_context* m_gguf_ctx = nullptr;
     std::string m_gguf_fname;
+    size_t m_num_threads;

     llama_model* m_llama_model_ptr = nullptr;
     llama_context* m_llama_ctx = nullptr;
5 changes: 3 additions & 2 deletions modules/llama_cpp_plugin/include/infer_request.hpp
@@ -12,8 +12,8 @@ namespace llama_cpp_plugin {

 class LlamaCppSyncInferRequest : public ISyncInferRequest {
 public:
-    explicit LlamaCppSyncInferRequest(const std::shared_ptr<const LlamaCppModel>& compiled_model);
-    virtual ~LlamaCppSyncInferRequest(){};
+    explicit LlamaCppSyncInferRequest(const std::shared_ptr<const LlamaCppModel>& compiled_model, size_t num_threads);
+    virtual ~LlamaCppSyncInferRequest() override;

     virtual void set_tensors_impl(const ov::Output<const ov::Node> port,
                                   const std::vector<ov::SoPtr<ov::ITensor>>& tensors) override;
@@ -24,6 +24,7 @@ class LlamaCppSyncInferRequest : public ISyncInferRequest {

 private:
     std::shared_ptr<const LlamaCppModel> m_compiled_model_ptr;
+    llama_context* m_llama_ctx;
 };

 } // namespace llama_cpp_plugin
9 changes: 5 additions & 4 deletions modules/llama_cpp_plugin/include/state.hpp
@@ -12,15 +12,16 @@ namespace llama_cpp_plugin {
 class LlamaCppState : public IVariableState {
 public:
     LlamaCppState() = delete;
-    LlamaCppState(const std::shared_ptr<const LlamaCppModel>& model_ptr)
-        : m_model_ptr(model_ptr),
+    LlamaCppState(llama_context* llama_context_ptr)
+        : m_llama_ctx_ptr(llama_context_ptr),
           IVariableState("llama_cpp_state") {}
     void reset() override {
-        llama_kv_cache_clear(m_model_ptr->m_llama_ctx);
+        OPENVINO_ASSERT(m_llama_ctx_ptr != nullptr);
+        llama_kv_cache_clear(m_llama_ctx_ptr);
     }

 private:
-    const std::shared_ptr<const LlamaCppModel>& m_model_ptr;
+    llama_context* m_llama_ctx_ptr;
 };
 } // namespace llama_cpp_plugin
 } // namespace ov
13 changes: 4 additions & 9 deletions modules/llama_cpp_plugin/src/compiled_model.cpp
@@ -9,7 +9,6 @@
 #include <openvino/opsets/opset13.hpp>
 #include <openvino/runtime/properties.hpp>
 #include <openvino/util/log.hpp>
-#include <thread>

 #include "infer_request.hpp"
 #include "plugin.hpp"
@@ -18,7 +17,6 @@ namespace ov {
 namespace llama_cpp_plugin {

 LlamaCppModel::~LlamaCppModel() {
-    llama_free(m_llama_ctx);
     llama_free_model(m_llama_model_ptr);
     llama_backend_free();
 }
@@ -27,15 +25,12 @@ LlamaCppModel::LlamaCppModel(const std::string& gguf_fname,
                              const std::shared_ptr<const IPlugin>& plugin,
                              size_t num_threads)
     : ICompiledModel(nullptr, plugin),
-      m_gguf_fname(gguf_fname) {
+      m_gguf_fname(gguf_fname),
+      m_num_threads(num_threads) {
     OPENVINO_DEBUG << "llama_cpp_plugin: loading llama model directly from GGUF... " << std::endl;
     llama_model_params mparams = llama_model_default_params();
     mparams.n_gpu_layers = 99;
     m_llama_model_ptr = llama_load_model_from_file(gguf_fname.c_str(), mparams);
-    llama_context_params cparams = llama_context_default_params();
-    cparams.n_threads = num_threads ? num_threads : std::thread::hardware_concurrency();
-    cparams.n_ctx = 0;  // this means that the actual n_ctx will be taken equal to the model's train-time value
-    m_llama_ctx = llama_new_context_with_model(m_llama_model_ptr, cparams);
     OPENVINO_DEBUG << "llama_cpp_plugin: llama model loaded successfully from GGUF..." << std::endl;

     auto input_ids = std::make_shared<ov::opset13::Parameter>(ov::element::Type_t::i64, ov::PartialShape({-1, -1}));
@@ -87,8 +82,8 @@ ov::Any LlamaCppModel::get_property(const std::string& name) const {
 }

 std::shared_ptr<ov::ISyncInferRequest> LlamaCppModel::create_sync_infer_request() const {
-    return std::make_shared<LlamaCppSyncInferRequest>(
-        std::static_pointer_cast<const LlamaCppModel>(shared_from_this()));
+    return std::make_shared<LlamaCppSyncInferRequest>(std::static_pointer_cast<const LlamaCppModel>(shared_from_this()),
+                                                      m_num_threads);
 }

 const std::vector<ov::Output<const ov::Node>>& LlamaCppModel::inputs() const {
21 changes: 16 additions & 5 deletions modules/llama_cpp_plugin/src/infer_request.cpp
@@ -5,6 +5,7 @@

 #include <memory>
 #include <openvino/runtime/ivariable_state.hpp>
+#include <thread>

 #include "llama.h"
 #include "openvino/runtime/make_tensor.hpp"
@@ -24,9 +25,14 @@ void allocate_tensor_impl(ov::SoPtr<ov::ITensor>& tensor,
     }
 }

-LlamaCppSyncInferRequest::LlamaCppSyncInferRequest(const std::shared_ptr<const LlamaCppModel>& compiled_model)
+LlamaCppSyncInferRequest::LlamaCppSyncInferRequest(const std::shared_ptr<const LlamaCppModel>& compiled_model,
+                                                   size_t num_threads)
     : ov::ISyncInferRequest(compiled_model) {
     OPENVINO_DEBUG << "llama_cpp_plugin: infer request ctor called\n";
+    llama_context_params cparams = llama_context_default_params();
+    cparams.n_threads = num_threads ? num_threads : std::thread::hardware_concurrency();
+    cparams.n_ctx = 0;  // this means that the actual n_ctx will be taken equal to the model's train-time value
+    m_llama_ctx = llama_new_context_with_model(compiled_model->m_llama_model_ptr, cparams);
     m_compiled_model_ptr = compiled_model;
     for (const auto& input : get_inputs()) {
         allocate_tensor(input, [input](ov::SoPtr<ov::ITensor>& tensor) {
@@ -97,8 +103,7 @@ void LlamaCppSyncInferRequest::infer() {
         }
     }

-    llama_context* ctx = m_compiled_model_ptr->m_llama_ctx;
-    int32_t sts = llama_decode(ctx, batch);
+    int32_t sts = llama_decode(m_llama_ctx, batch);

     if (sts != 0) {
         OPENVINO_THROW("llama_decode failed with code ", sts);
@@ -112,7 +117,7 @@ void LlamaCppSyncInferRequest::infer() {
     for (size_t batch_idx = 0; batch_idx < batch_size; batch_idx++) {
         for (size_t seq_idx = 0; seq_idx < sequence_length; seq_idx++) {
             size_t pos = batch_idx * sequence_length + seq_idx;
-            float* logits_from_llama = llama_get_logits_ith(ctx, pos);
+            float* logits_from_llama = llama_get_logits_ith(m_llama_ctx, pos);
             std::copy(logits_from_llama, logits_from_llama + n_vocab, output_tensor_data_ptr + pos * n_vocab);
         }
     }
@@ -132,7 +137,13 @@ std::vector<ov::ProfilingInfo> LlamaCppSyncInferRequest::get_profiling_info() const {

 std::vector<ov::SoPtr<ov::IVariableState>> LlamaCppSyncInferRequest::query_state() const {
     OPENVINO_DEBUG << "llama_cpp_plugin: query_state() called\n";
-    return {std::static_pointer_cast<ov::IVariableState>(std::make_shared<LlamaCppState>(m_compiled_model_ptr))};
+    return {std::static_pointer_cast<ov::IVariableState>(std::make_shared<LlamaCppState>(m_llama_ctx))};
 }
+
+LlamaCppSyncInferRequest::~LlamaCppSyncInferRequest() {
+    if (m_llama_ctx != nullptr) {
+        llama_free(m_llama_ctx);
+    }
+}
 } // namespace llama_cpp_plugin
 } // namespace ov
45 changes: 42 additions & 3 deletions modules/llama_cpp_plugin/tests/functional/src/reset_state.cpp
@@ -42,16 +42,55 @@ TEST_F(CompiledModelTest, ResetStateGPT2) {
     SetUp();

     ov::InferRequest lm_bad = model.create_infer_request();
-    std::vector<float> logits_lennon_bad = infer_and_get_last_logits(lm, GPT2_LENNON_PROMPT_TOKEN_IDS, 0);
+    std::vector<float> logits_lennon_bad = infer_and_get_last_logits(lm_bad, GPT2_LENNON_PROMPT_TOKEN_IDS, 0);

     // no reset_state on purpose

-    std::vector<float> logits_sun_bad = infer_and_get_last_logits(lm_reset,
+    std::vector<float> logits_sun_bad = infer_and_get_last_logits(lm_bad,
                                                                   GPT2_SUN_PROMPT_TOKEN_IDS,
                                                                   0);  // GPT2_LENNON_PROMPT_TOKEN_IDS.size());

-    std::vector<int64_t> out_token_ids_bad = generate_n_tokens_with_positions(lm_reset,
+    std::vector<int64_t> out_token_ids_bad = generate_n_tokens_with_positions(lm_bad,
                                                                               get_token_from_logits(logits_sun_reset),
                                                                               NUM_TOKENS_TO_GENERATE,
                                                                               GPT2_SUN_PROMPT_TOKEN_IDS.size());
     ASSERT_NE(out_token_ids_bad, out_token_ids_ref);
 }
+
+TEST_F(CompiledModelTest, StatesForDifferentInferRequestsAreIndependentGPT2) {
+    // Take two infer requests, process two different prompts with same position IDs, but for one of them, do
+    // .reset_state() in-between the inferences - check that the state is reset independently.
+
+    // the "new" sequence should have the same number of tokens as the previous one for this to work
+    std::vector<int64_t> MODIFIED_PROMPT_TOKEN_IDS = GPT2_LENNON_PROMPT_TOKEN_IDS;
+    MODIFIED_PROMPT_TOKEN_IDS.push_back(30);  // extra newline
+    ASSERT_EQ(GPT2_SUN_PROMPT_TOKEN_IDS.size(), MODIFIED_PROMPT_TOKEN_IDS.size());
+
+    ov::InferRequest first_infer_request = model.create_infer_request();
+    std::vector<float> logits_first_ref = infer_and_get_last_logits(first_infer_request, GPT2_SUN_PROMPT_TOKEN_IDS, 0);
+
+    ov::InferRequest another_infer_request = model.create_infer_request();
+    std::vector<float> logits_another_ref =
+        infer_and_get_last_logits(another_infer_request, GPT2_SUN_PROMPT_TOKEN_IDS, 0);
+
+    first_infer_request.reset_state();
+
+    std::vector<float> logits_first_new_tokens_old_positions =
+        infer_and_get_last_logits(first_infer_request, MODIFIED_PROMPT_TOKEN_IDS, 0);
+    std::vector<int64_t> out_tokens_first =
+        generate_n_tokens_with_positions(first_infer_request,
+                                         get_token_from_logits(logits_first_new_tokens_old_positions),
+                                         NUM_TOKENS_TO_GENERATE,
+                                         MODIFIED_PROMPT_TOKEN_IDS.size());
+
+    // not resetting another_infer_request state on purpose
+    std::vector<float> logits_another_new_tokens_old_positions =
+        infer_and_get_last_logits(another_infer_request, MODIFIED_PROMPT_TOKEN_IDS, 0);
+    std::vector<int64_t> out_tokens_another =
+        generate_n_tokens_with_positions(another_infer_request,
+                                         get_token_from_logits(logits_another_new_tokens_old_positions),
+                                         NUM_TOKENS_TO_GENERATE,
+                                         MODIFIED_PROMPT_TOKEN_IDS.size());
+
+    EXPECT_NE(out_tokens_another, out_tokens_first);
+}
