From 11ac0a2460caff13aa4ac56f53d428589428ce78 Mon Sep 17 00:00:00 2001
From: sbalandi
Date: Fri, 15 Nov 2024 18:48:13 +0000
Subject: [PATCH] chat history update

---
 src/cpp/src/llm_pipeline.cpp | 36 ++++++++++++++++++++++++++++--------
 src/cpp/src/lm_encoding.cpp  | 24 ++++--------------------
 2 files changed, 32 insertions(+), 28 deletions(-)

diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp
index 064e314ae9..174b138324 100644
--- a/src/cpp/src/llm_pipeline.cpp
+++ b/src/cpp/src/llm_pipeline.cpp
@@ -32,7 +32,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
     bool m_is_cache_empty = true;
     ChatHistory m_history;
     std::string m_templated_chat_history = {};
-    TokenizedInputs m_tokenized_chat_history;
+    ov::Tensor m_tokenized_chat_history = ov::Tensor(ov::element::i64, {0, 0});
 
     StatefulLLMPipeline(
         const ov::InferRequest& request,
@@ -116,7 +116,6 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
                 encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens);
             }
             m_templated_chat_history = new_templated_chat_history;
-            m_tokenized_chat_history = new_chat_tokens;
             // TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied
         } else {
             encoded_input = m_tokenizer.encode(prompt);
@@ -234,13 +233,23 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
             concatenated_attention_mask = attention_mask;
         }
 
+        if (is_chat_conversation) {
+            ov::Tensor new_tokenized_chat_history = ov::Tensor{ov::element::i64, {batch_size, m_tokenized_chat_history.get_shape().at(1) + input_ids.get_shape().at(1)}};
+            auto start_chat_hst = m_tokenized_chat_history.data<int64_t>();
+            std::copy(start_chat_hst, start_chat_hst + m_tokenized_chat_history.get_size(), new_tokenized_chat_history.data<int64_t>());
+            std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(),
+                      new_tokenized_chat_history.data<int64_t>() + m_tokenized_chat_history.get_size());
+
+            m_tokenized_chat_history = new_tokenized_chat_history;
+        }
+
         bool position_ids_available = (num_inputs == 4);
         std::optional<ov::Tensor> position_ids = std::nullopt;
         if (position_ids_available) {
             if (is_chat_conversation && config.is_beam_search()) {
-                position_ids = ov::Tensor{ov::element::i64, m_tokenized_chat_history.input_ids.get_shape()};
-                size_t start_pos = kv_cache_len - (m_tokenized_chat_history.input_ids.get_shape().at(1) - input_ids.get_shape().at(1));
-                size_t seq_length = m_tokenized_chat_history.input_ids.get_shape().at(1);
+                position_ids = ov::Tensor{ov::element::i64, m_tokenized_chat_history.get_shape()};
+                size_t start_pos = kv_cache_len - (m_tokenized_chat_history.get_shape().at(1) - input_ids.get_shape().at(1));
+                size_t seq_length = m_tokenized_chat_history.get_shape().at(1);
 
                 utils::initialize_position_ids(*position_ids, concatenated_attention_mask, seq_length, start_pos);
             } else {
@@ -261,8 +270,8 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         for (size_t request_id = 0; request_id < batch_size; request_id++) {
             SequenceGroup::Ptr sequence_group;
             if (is_chat_conversation) {
-                sequence_group = std::make_shared<SequenceGroup>(request_id, m_tokenized_chat_history.input_ids, config, block_size, enable_prefix_caching);
-                sequence_group->update_processed_tokens_num(m_tokenized_chat_history.input_ids.get_shape().at(1) - 1);
+                sequence_group = std::make_shared<SequenceGroup>(request_id, m_tokenized_chat_history, config, block_size, enable_prefix_caching);
+                sequence_group->update_processed_tokens_num(m_tokenized_chat_history.get_shape().at(1) - 1);
             } else {
                 size_t seq_len = input_ids.get_shape().at(1);
                 size_t batch_offset = request_id * seq_len;
@@ -282,7 +291,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
 
         Sampler sampler = Sampler(m_tokenizer);
         // we can't properly refer to history in case of chat scenario with beam search, so reset_kv_state and use the whole history for each new propmt
-        auto input_tokens = (is_chat_conversation && config.is_beam_search()) ? m_tokenized_chat_history.input_ids : input_ids;
+        auto input_tokens = (is_chat_conversation && config.is_beam_search()) ? m_tokenized_chat_history : input_ids;
         result = ov::genai::get_lm_encoded_results(m_model_runner, input_tokens, concatenated_attention_mask, streamer_ptr, sampler, requests, position_ids, std::nullopt);
 
         if (!is_chat_conversation || config.is_beam_search()) {
@@ -291,6 +300,15 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
 
         if (is_chat_conversation) {
             m_is_cache_empty = false;
+
+            // remove eos token, if it is at the end
+            auto last_token = result.tokens[0].back() == config.eos_token_id ? result.tokens[0].size() - 1 : result.tokens[0].size();
+            ov::Tensor new_tokenized_chat_history = ov::Tensor{ov::element::i64, {batch_size, m_tokenized_chat_history.get_shape().at(1) + last_token}};
+            auto start_chat_hst = m_tokenized_chat_history.data<int64_t>();
+            std::copy(start_chat_hst, start_chat_hst + m_tokenized_chat_history.get_size(), new_tokenized_chat_history.data<int64_t>());
+            std::copy(result.tokens[0].begin(), result.tokens[0].begin() + last_token, new_tokenized_chat_history.data<int64_t>() + m_tokenized_chat_history.get_size());
+
+            m_tokenized_chat_history = new_tokenized_chat_history;
         }
 
         auto stop_time = std::chrono::steady_clock::now();
@@ -311,6 +329,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
             m_is_cache_empty = true;
            m_history = {};
             m_templated_chat_history = "";
+            m_tokenized_chat_history = ov::Tensor(ov::element::i64, {0, 0});
         }
         if (system_message.empty())
             return;
@@ -328,6 +347,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
             m_is_cache_empty = true;
             m_history.clear();
             m_templated_chat_history.clear();
+            m_tokenized_chat_history = ov::Tensor(ov::element::i64, {0, 0});
         }
     }
 };
diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp
index 912f4ffd0d..d0ab0a4f3c 100644
--- a/src/cpp/src/lm_encoding.cpp
+++ b/src/cpp/src/lm_encoding.cpp
@@ -19,15 +19,6 @@
 namespace ov {
 namespace genai {
 
-void reset_all_inputs_to_empty_tensors(ov::InferRequest& request, bool is_vlm = false) {
-    if (!is_vlm)
-        request.set_tensor("input_ids", ov::Tensor(ov::element::i64, {0, 0}));
-
-    request.set_tensor("beam_idx", ov::Tensor(ov::element::i32, {0}));
-    if (request.get_compiled_model().inputs().size() == 4)
-        request.set_tensor("position_ids", ov::Tensor(ov::element::i64, {0, 0}));
-}
-
 void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention_mask) {
     const size_t batch_size = attention_mask.get_shape().at(0);
     const size_t sequence_length = attention_mask.get_shape().at(1);
@@ -173,15 +164,11 @@ EncodedResults get_lm_encoded_results(
 
         if (m_embedding.has_value()) {
             const ov::Tensor& embed_prompt_tensor = (*m_embedding).infer(new_input_ids);
-
-            m_llm.get_tensor("inputs_embeds").set_shape(embed_prompt_tensor.get_shape());
             m_llm.set_tensor("inputs_embeds", embed_prompt_tensor);
         } else {
-            m_llm.get_tensor("input_ids").set_shape({total_num_tokens, 1});
             m_llm.set_tensor("input_ids", new_input_ids);
         }
 
-        m_llm.get_tensor("beam_idx").set_shape({ total_num_tokens });
         m_llm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {total_num_tokens}, next_beams.data()});
 
         update_attention_mask_with_beams(m_llm.get_tensor("attention_mask"), next_beams);
@@ -221,18 +208,15 @@ EncodedResults get_lm_encoded_results(
         streamer_ptr->end();
     }
 
-    reset_all_inputs_to_empty_tensors(m_llm, m_embedding.has_value());
-
-    for (size_t i = 0; i < sequence_groups.size(); i++) {
-        auto request = sequence_groups[i];
+    for (auto& sequence_group : sequence_groups) {
         // sequences is sorted by cumulative_log_prob with length_penalty
-        auto outputs = request->get_finished_sequences();
+        auto outputs = sequence_group->get_finished_sequences();
 
-        auto num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, outputs.size());
+        auto num_outputs = std::min(sequence_group->get_sampling_parameters().num_return_sequences, outputs.size());
         for (size_t output_idx = 0; output_idx < num_outputs; ++output_idx) {
             const auto& output = outputs[output_idx];
             results.tokens.push_back(std::move(output->get_generated_ids()));
-            results.scores.push_back(output->get_cumulative_score_with_length_penalty(request->get_sampling_parameters()));
+            results.scores.push_back(output->get_cumulative_score_with_length_penalty(sequence_group->get_sampling_parameters()));
         }
     }
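For context, a minimal standalone sketch of the history-append pattern this patch introduces: the i64 chat-history tensor is grown by allocating a larger tensor and copying both halves rather than mutating it in place. The helper name append_tokens_to_history, the main() driver, and the fixed batch of 1 are illustrative assumptions, not code from the patch.

#include <openvino/openvino.hpp>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Append newly tokenized ids to an i64 chat-history tensor by allocating a
// larger tensor and copying both halves, mirroring the copy logic the patch
// adds to StatefulLLMPipeline. Assumes a single conversation (batch of 1).
static ov::Tensor append_tokens_to_history(const ov::Tensor& history, const std::vector<int64_t>& new_tokens) {
    const size_t old_len = history.get_shape().at(1);
    ov::Tensor updated{ov::element::i64, {1, old_len + new_tokens.size()}};
    auto* dst = updated.data<int64_t>();
    if (old_len > 0) {
        const auto* src = history.data<int64_t>();
        std::copy(src, src + history.get_size(), dst);
    }
    std::copy(new_tokens.begin(), new_tokens.end(), dst + old_len);
    return updated;
}

int main() {
    // Start from the same empty {0, 0} tensor the patch uses to reset the history.
    ov::Tensor history(ov::element::i64, {0, 0});
    history = append_tokens_to_history(history, {1, 15496, 11});  // hypothetical prompt token ids
    history = append_tokens_to_history(history, {50256});         // hypothetical generated token id
    return history.get_shape().at(1) == 4 ? 0 : 1;
}

The patch applies this copy-based append twice per turn: once for the tokenized prompt before inference, and once for the sampled ids afterwards (dropping a trailing EOS token), so the accumulated token history can be replayed for beam search on the next prompt.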