diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp
index a1fc672a1a..6face83f02 100644
--- a/src/cpp/src/llm_pipeline.cpp
+++ b/src/cpp/src/llm_pipeline.cpp
@@ -23,19 +23,30 @@
 #include "debug_utils.hpp"
 
+namespace {
+
+enum class GenerationChatInputsType {
+    UNDEF = 0, // Default value, type of inputs is not defined
+    STRING = 1, // Type of inputs is StringInputs
+    ENCODED_INPUTS = 2, // Type of inputs is EncodedInputs
+};
+
+} // namespace
+
 namespace ov {
 namespace genai {
 
 class StatefulLLMPipeline final : public LLMPipelineImplBase {
 public:
     ov::InferRequest m_model_runner;
-
     bool is_chat_conversation = false;
     bool m_history_available = false;
-    std::optional<int64_t> m_selected_beam = std::nullopt;
+    bool m_is_cache_empty = true;
     ChatHistory m_history;
     std::string m_templated_chat_history = {};
     std::vector<int64_t> m_tokenized_chat_history;
+    GenerationChatInputsType m_chat_input_type = GenerationChatInputsType::UNDEF;
+    std::optional<int64_t> m_last_disappeared_token = std::nullopt;
 
     StatefulLLMPipeline(
         const ov::InferRequest& request,
@@ -87,6 +98,13 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         OptionalGenerationConfig generation_config,
         StreamerVariant streamer
     ) override {
+        if (is_chat_conversation && m_chat_input_type == GenerationChatInputsType::UNDEF)
+            m_chat_input_type = GenerationChatInputsType::STRING;
+
+        if (is_chat_conversation)
+            OPENVINO_ASSERT(m_chat_input_type != GenerationChatInputsType::ENCODED_INPUTS,
+                            "Chat doesn't support switching between input types. Please, continue using EncodedInputs or restart the chat.");
+
         auto start_time = std::chrono::steady_clock::now();
         GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;
         TokenizedInputs encoded_input;
@@ -119,6 +137,16 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
                 encoded_input = new_chat_tokens;
             }
             m_templated_chat_history = new_templated_chat_history;
+
+            m_tokenized_chat_history.clear();
+            std::copy(new_chat_tokens.input_ids.data<int64_t>(), new_chat_tokens.input_ids.data<int64_t>() + new_chat_tokens.input_ids.get_size(),
+                      std::back_inserter(m_tokenized_chat_history));
+
+            // no need to add m_last_disappeared_token to encoded_input, it was kept by subtract_chat_tokenized_inputs
+            if (m_last_disappeared_token.has_value() && *m_last_disappeared_token == encoded_input.input_ids.data<int64_t>()[0]) {
+                m_last_disappeared_token = std::nullopt;
+            }
+
             // TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied
         } else {
             encoded_input = m_tokenizer.encode(prompt);
@@ -172,6 +200,14 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         OptionalGenerationConfig generation_config,
         StreamerVariant streamer
     ) override {
+        if (is_chat_conversation && m_chat_input_type == GenerationChatInputsType::UNDEF)
+            m_chat_input_type = GenerationChatInputsType::ENCODED_INPUTS;
+
+        if (is_chat_conversation)
+            // if chat was run in StringInputs mode, but generate() was then called with EncodedInputs, the last m_history entry will have the assistant role
+            OPENVINO_ASSERT(m_chat_input_type == GenerationChatInputsType::ENCODED_INPUTS || m_history.back()["role"] == "user",
+                            "Chat doesn't support switching between input types. Please, continue using StringInputs or restart the chat.");
+
         auto start_time = std::chrono::steady_clock::now();
         ov::Tensor input_ids;
         ov::Tensor attention_mask;
@@ -183,6 +219,16 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
             attention_mask = data->attention_mask;
         }
 
+        if (is_chat_conversation && m_chat_input_type == GenerationChatInputsType::ENCODED_INPUTS) {
+            std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history));
+        }
+
+        // Tail of previous output in chat mode is missing in KV cache.
+        if (is_chat_conversation && m_last_disappeared_token.has_value()) {
+            attention_mask = ov::genai::utils::push_front_inputs(attention_mask, std::vector<int64_t>{1});
+            input_ids = ov::genai::utils::push_front_inputs(input_ids, std::vector<int64_t>{*m_last_disappeared_token});
+        }
+
         GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;
 
         // If eos_token_id was not provided, take value from default m_generation_config
@@ -214,16 +260,11 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
             "(input_ids, attention_mask, position_ids, beam_idx) "
             "but you have '" + std::to_string(num_inputs) + "' inputs");
 
-        if (is_chat_conversation) {
-            std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history));
-        }
 
         ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data());
-        bool is_cache_empty = !m_selected_beam.has_value();
-        size_t kv_cache_len = 0;
         ov::Tensor concatenated_attention_mask;
         if (is_chat_conversation && m_history_available) {
-            if (is_cache_empty) {
+            if (m_is_cache_empty) {
                 attention_mask = ov::genai::utils::init_attention_mask(tokenized_chat_history);
                 concatenated_attention_mask = attention_mask;
             } else {
@@ -249,7 +290,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         bool position_ids_available = (num_inputs == 4);
         std::optional<ov::Tensor> position_ids = std::nullopt;
         if (position_ids_available) {
-            if (is_chat_conversation && is_cache_empty) {
+            if (is_chat_conversation && m_is_cache_empty) {
                 position_ids = ov::Tensor{ov::element::i64, tokenized_chat_history.get_shape()};
             } else {
                 position_ids = ov::Tensor{ov::element::i64, input_ids.get_shape()};
             }
@@ -261,7 +302,6 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
             m_adapter_controller->apply(m_model_runner, config.adapters);
         }
 
-        ov::genai::EncodedResults result;
         std::vector<SequenceGroup::Ptr> requests;
         size_t block_size = 1;
         bool enable_prefix_caching = false;
@@ -285,25 +325,25 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         Sampler sampler = Sampler(m_tokenizer);
         // we can't properly refer to history in case of chat scenario with beam search, so reset_kv_state and use the whole history for each new prompt
         auto input_tokens = input_ids;
-        if (is_chat_conversation && is_cache_empty) {
+        if (is_chat_conversation && m_is_cache_empty) {
            input_tokens = tokenized_chat_history;
         }
-        result = ov::genai::get_lm_encoded_results(m_model_runner, input_tokens, concatenated_attention_mask, streamer_ptr, sampler, requests, position_ids, std::nullopt);
+        ov::genai::EncodedResults result = ov::genai::get_lm_encoded_results(m_model_runner, input_tokens, concatenated_attention_mask,
+                                                                             streamer_ptr, sampler, requests, position_ids, std::nullopt);
 
-        m_selected_beam = 0;
+        m_is_cache_empty = false;
         if (!is_chat_conversation || config.is_beam_search()) {
             reset_kv_state();
+            m_is_cache_empty = true;
        }
 
         if (is_chat_conversation) {
             m_history_available = true;
+            m_last_disappeared_token = result.tokens[0].back();
         }
 
-        if (is_chat_conversation) {
-            auto decoded_result = m_tokenizer.decode(result.tokens[0]);
-            auto answer = m_tokenizer.encode(decoded_result, ov::genai::add_special_tokens(false)).input_ids;
-            std::copy(answer.data<int64_t>(), answer.data<int64_t>() + answer.get_size(), std::back_inserter(m_tokenized_chat_history));
+        if (is_chat_conversation && m_chat_input_type == GenerationChatInputsType::ENCODED_INPUTS) {
+            std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
         }
 
         auto stop_time = std::chrono::steady_clock::now();
@@ -319,7 +359,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
 
     void start_chat(const std::string& system_message) override {
         is_chat_conversation = true;
-        m_selected_beam = std::nullopt;
+        m_is_cache_empty = true;
         if (m_history_available) {
             reset_kv_state();
             m_history_available = false;
@@ -338,7 +378,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
 
     void finish_chat() override {
         is_chat_conversation = false;
-        m_selected_beam = std::nullopt;
+        m_is_cache_empty = true;
         if (m_history_available) {
             reset_kv_state();
             m_history_available = false;
diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp
index 15b497ca76..6d234aa07e 100644
--- a/src/cpp/src/lm_encoding.cpp
+++ b/src/cpp/src/lm_encoding.cpp
@@ -9,12 +9,11 @@
 #include
 #include
 
+#include "utils.hpp"
+#include "debug_utils.hpp"
 #include "lm_encoding.hpp"
 #include "openvino/genai/perf_metrics.hpp"
 
-#include "utils.hpp"
-
-#include "debug_utils.hpp"
 
 namespace ov {
 namespace genai {
diff --git a/src/cpp/src/lm_encoding.hpp b/src/cpp/src/lm_encoding.hpp
index f95eae1436..0a342f0a37 100644
--- a/src/cpp/src/lm_encoding.hpp
+++ b/src/cpp/src/lm_encoding.hpp
@@ -12,9 +12,5 @@ EncodedResults get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor&
                                       const std::shared_ptr<StreamerBase>& streamer_ptr, Sampler& sampler, std::vector<SequenceGroup::Ptr> sequence_groups, std::optional<ov::Tensor> position_ids, std::optional<EmbeddingsModel> m_embedding);
 
-void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector<int64_t> next_beams);
-
-void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention_mask);
-
 }
 }
diff --git a/src/cpp/src/utils.cpp b/src/cpp/src/utils.cpp
index 50c2e0c49e..0ef182829f 100644
--- a/src/cpp/src/utils.cpp
+++ b/src/cpp/src/utils.cpp
@@ -266,6 +266,14 @@ ov::Core singleton_core() {
     return core;
 }
 
+ov::Tensor push_front_inputs(const ov::Tensor& base_tensor, std::vector<int64_t> add_to_front) {
+    ov::Tensor new_tensor = ov::Tensor{ov::element::i64, {base_tensor.get_shape().at(0), base_tensor.get_shape().at(1) + add_to_front.size()}};
+    auto new_tensor_data = new_tensor.data<int64_t>();
+    std::copy(add_to_front.begin(), add_to_front.end(), new_tensor_data);
+    std::copy(base_tensor.data<int64_t>(), base_tensor.data<int64_t>() + base_tensor.get_size(), new_tensor_data + add_to_front.size());
+    return new_tensor;
+}
+
 } // namespace utils
 } // namespace genai
 } // namespace ov
diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp
index 3487fccb81..e807a0bbbf 100644
--- a/src/cpp/src/utils.hpp
+++ b/src/cpp/src/utils.hpp
@@ -86,6 +86,8 @@ void slice_matmul_statefull_model(std::shared_ptr<ov::Model> model);
 
 ov::Core singleton_core();
 
+ov::Tensor push_front_inputs(const ov::Tensor& base_tensor, std::vector<int64_t> add_to_front);
+
 } // namespace utils
 } // namespace genai
 } // namespace ov
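The new ov::genai::utils::push_front_inputs helper is what re-inserts the token that "disappeared" from the KV cache between chat turns. Below is a minimal standalone sketch (not part of the patch) of that prepend behaviour for an i64 [1, N] tensor; the header path and the token values are assumptions made purely for illustration.

// Sketch of the prepend logic added in utils.cpp, compiled against the OpenVINO runtime.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

#include <openvino/runtime/tensor.hpp>

// Allocate [batch, old_len + prefix_len], write the prefix first, then copy the original data after it.
static ov::Tensor push_front_inputs(const ov::Tensor& base_tensor, std::vector<int64_t> add_to_front) {
    ov::Tensor new_tensor{ov::element::i64, {base_tensor.get_shape().at(0), base_tensor.get_shape().at(1) + add_to_front.size()}};
    int64_t* data = new_tensor.data<int64_t>();
    std::copy(add_to_front.begin(), add_to_front.end(), data);
    std::copy(base_tensor.data<int64_t>(), base_tensor.data<int64_t>() + base_tensor.get_size(), data + add_to_front.size());
    return new_tensor;
}

int main() {
    // Hypothetical next-turn prompt tokens; 99 stands in for the last token of the
    // previous answer, which is missing from the KV cache.
    ov::Tensor input_ids{ov::element::i64, {1, 3}};
    int64_t* ids = input_ids.data<int64_t>();
    ids[0] = 11; ids[1] = 22; ids[2] = 33;

    ov::Tensor extended = push_front_inputs(input_ids, {99});
    for (size_t i = 0; i < extended.get_size(); ++i)
        std::cout << extended.data<int64_t>()[i] << ' ';   // prints: 99 11 22 33
    std::cout << '\n';
}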