diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp
index ca01e8e11c..a1fc672a1a 100644
--- a/src/cpp/src/llm_pipeline.cpp
+++ b/src/cpp/src/llm_pipeline.cpp
@@ -31,7 +31,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
     ov::InferRequest m_model_runner;
     bool is_chat_conversation = false;
-    bool m_is_cache_empty = true;
+    bool m_history_available = false;
     std::optional<int64_t> m_selected_beam = std::nullopt;
     ChatHistory m_history;
     std::string m_templated_chat_history = {};
@@ -112,11 +112,11 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
             auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
             // Do not add special tokens in chat scenario to be aligned with HF.
             auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false));
-            if (m_is_cache_empty) {
-                encoded_input = new_chat_tokens;
-            } else {
+            if (m_history_available) {
                 auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false));
                 encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens);
+            } else {
+                encoded_input = new_chat_tokens;
             }
             m_templated_chat_history = new_templated_chat_history;
             // TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied
@@ -218,12 +218,15 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
             std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history));
         }
         ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data());
-        bool kv_history_available = m_selected_beam.has_value();
+        bool is_cache_empty = !m_selected_beam.has_value();

         size_t kv_cache_len = 0;
         ov::Tensor concatenated_attention_mask;
-        if (is_chat_conversation && !m_is_cache_empty) {
-            if (kv_history_available) {
+        if (is_chat_conversation && m_history_available) {
+            if (is_cache_empty) {
+                attention_mask = ov::genai::utils::init_attention_mask(tokenized_chat_history);
+                concatenated_attention_mask = attention_mask;
+            } else {
                 OPENVINO_ASSERT(batch_size == 1, "continuation of generation is possible only for batch 1");
                 // If history is saved in KV cache, concatenate new attention_mask with the already existing.
                 // Between subsequent runs attention_mask should not be modified.
@@ -238,9 +241,6 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
                 std::copy(attention_mask.data<int64_t>(), attention_mask.data<int64_t>() + prompt_len, new_atten_mask.data<int64_t>() + kv_cache_len);
                 concatenated_attention_mask = new_atten_mask;
-            } else {
-                attention_mask = ov::genai::utils::init_attention_mask(tokenized_chat_history);
-                concatenated_attention_mask = attention_mask;
             }
         } else {
             concatenated_attention_mask = attention_mask;
         }
@@ -249,7 +249,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         bool position_ids_available = (num_inputs == 4);
         std::optional<ov::Tensor> position_ids = std::nullopt;
         if (position_ids_available) {
-            if (is_chat_conversation && !kv_history_available) {
+            if (is_chat_conversation && is_cache_empty) {
                 position_ids = ov::Tensor{ov::element::i64, tokenized_chat_history.get_shape()};
             } else {
                 position_ids = ov::Tensor{ov::element::i64, input_ids.get_shape()};
@@ -285,7 +285,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         Sampler sampler = Sampler(m_tokenizer);
         // we can't properly refer to history in case of chat scenario with beam search, so reset_kv_state and use the whole history for each new propmt
         auto input_tokens = input_ids;
-        if (is_chat_conversation && !kv_history_available) {
+        if (is_chat_conversation && is_cache_empty) {
             input_tokens = tokenized_chat_history;
         }
         result = ov::genai::get_lm_encoded_results(m_model_runner, input_tokens, concatenated_attention_mask, streamer_ptr, sampler, requests, position_ids, std::nullopt);
@@ -297,7 +297,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         }

         if (is_chat_conversation) {
-            m_is_cache_empty = false;
+            m_history_available = true;
         }

         if (is_chat_conversation) {
@@ -320,9 +320,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
     void start_chat(const std::string& system_message) override {
         is_chat_conversation = true;
         m_selected_beam = std::nullopt;
-        if (!m_is_cache_empty) {
+        if (m_history_available) {
             reset_kv_state();
-            m_is_cache_empty = true;
+            m_history_available = false;
             m_history = {};
             m_templated_chat_history = "";
             m_tokenized_chat_history = {};
@@ -339,9 +339,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
     void finish_chat() override {
         is_chat_conversation = false;
         m_selected_beam = std::nullopt;
-        if (!m_is_cache_empty) {
+        if (m_history_available) {
             reset_kv_state();
-            m_is_cache_empty = true;
+            m_history_available = false;
             m_history.clear();
             m_templated_chat_history.clear();
             m_tokenized_chat_history = {};
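Note: the sketch below is a minimal standalone illustration of the flag semantics introduced by this change, not code from the pipeline itself. m_history_available replaces the inverted m_is_cache_empty flag, and is_cache_empty is now derived per generate() call from m_selected_beam. The ChatStateSketch type, the beam_search parameter, and the point at which m_selected_beam is assigned are assumptions made for the example; only the flag names and the transitions in start_chat/generate/finish_chat mirror the hunks above.

// Standalone sketch of the chat-state flags after the rename (illustrative only).
#include <cstdint>
#include <iostream>
#include <optional>

struct ChatStateSketch {
    bool is_chat_conversation = false;
    bool m_history_available = false;        // was: !m_is_cache_empty
    std::optional<int64_t> m_selected_beam;  // unset -> KV cache cannot be continued

    void start_chat() {
        is_chat_conversation = true;
        m_selected_beam = std::nullopt;
        if (m_history_available) {           // was: if (!m_is_cache_empty)
            m_history_available = false;     // accumulated history is dropped
        }
    }

    void generate(bool beam_search) {
        bool is_cache_empty = !m_selected_beam.has_value();
        if (is_chat_conversation && m_history_available && is_cache_empty) {
            // Beam search cannot continue the KV cache, so the whole tokenized
            // chat history is re-fed on this turn (see the @@ -285 hunk).
            std::cout << "re-running full tokenized chat history\n";
        }
        if (!beam_search) {
            m_selected_beam = 0;             // assumption: non-beam sampling leaves a reusable KV cache
        }
        if (is_chat_conversation) {
            m_history_available = true;      // was: m_is_cache_empty = false
        }
    }

    void finish_chat() {
        is_chat_conversation = false;
        m_selected_beam = std::nullopt;
        m_history_available = false;
    }
};

int main() {
    ChatStateSketch state;
    state.start_chat();
    state.generate(/*beam_search=*/true);   // history kept only as tokens
    state.generate(/*beam_search=*/false);  // re-runs the full history, then the KV cache becomes reusable
    state.finish_chat();
}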