
Commit

chat history update
sbalandi committed Nov 15, 2024
1 parent c7b9843 commit 11ac0a2
Showing 2 changed files with 32 additions and 28 deletions.
36 changes: 28 additions & 8 deletions src/cpp/src/llm_pipeline.cpp
@@ -32,7 +32,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
bool m_is_cache_empty = true;
ChatHistory m_history;
std::string m_templated_chat_history = {};
TokenizedInputs m_tokenized_chat_history;
ov::Tensor m_tokenized_chat_history = ov::Tensor(ov::element::i64, {0, 0});

StatefulLLMPipeline(
const ov::InferRequest& request,
@@ -116,7 +116,6 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens);
}
m_templated_chat_history = new_templated_chat_history;
m_tokenized_chat_history = new_chat_tokens;
// TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied
} else {
encoded_input = m_tokenizer.encode(prompt);
Expand Down Expand Up @@ -234,13 +233,23 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
concatenated_attention_mask = attention_mask;
}

if (is_chat_conversation) {
ov::Tensor new_tokenized_chat_history = ov::Tensor{ov::element::i64, {batch_size, m_tokenized_chat_history.get_shape().at(1) + input_ids.get_shape().at(1)}};
auto start_chat_hst = m_tokenized_chat_history.data<int64_t>();
std::copy(start_chat_hst, start_chat_hst + m_tokenized_chat_history.get_size(), new_tokenized_chat_history.data<int64_t>());
std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(),
new_tokenized_chat_history.data<int64_t>() + m_tokenized_chat_history.get_size());

m_tokenized_chat_history = new_tokenized_chat_history;
}

bool position_ids_available = (num_inputs == 4);
std::optional<ov::Tensor> position_ids = std::nullopt;
if (position_ids_available) {
if (is_chat_conversation && config.is_beam_search()) {
position_ids = ov::Tensor{ov::element::i64, m_tokenized_chat_history.input_ids.get_shape()};
size_t start_pos = kv_cache_len - (m_tokenized_chat_history.input_ids.get_shape().at(1) - input_ids.get_shape().at(1));
size_t seq_length = m_tokenized_chat_history.input_ids.get_shape().at(1);
position_ids = ov::Tensor{ov::element::i64, m_tokenized_chat_history.get_shape()};
size_t start_pos = kv_cache_len - (m_tokenized_chat_history.get_shape().at(1) - input_ids.get_shape().at(1));
size_t seq_length = m_tokenized_chat_history.get_shape().at(1);

utils::initialize_position_ids(*position_ids, concatenated_attention_mask, seq_length, start_pos);
} else {
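
For reference, here is a minimal sketch of the append pattern used in the hunk above: since ov::Tensor has no resize that preserves its contents, the tokenized history is grown by allocating a larger tensor and copying the old data followed by the new tokens. The helper name append_tokens, the includes, and the single-batch {1, seq_len} layout are assumptions for illustration, not part of the commit.

#include <algorithm>
#include <cstdint>
#include "openvino/runtime/tensor.hpp"

// Sketch only: append new_tokens (shape {1, n}) to history (shape {1, m}), both i64.
ov::Tensor append_tokens(const ov::Tensor& history, const ov::Tensor& new_tokens) {
    ov::Tensor grown{ov::element::i64, {1, history.get_shape().at(1) + new_tokens.get_shape().at(1)}};
    auto* dst = grown.data<int64_t>();
    std::copy(history.data<int64_t>(), history.data<int64_t>() + history.get_size(), dst);
    std::copy(new_tokens.data<int64_t>(), new_tokens.data<int64_t>() + new_tokens.get_size(),
              dst + history.get_size());
    return grown;
}
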
@@ -261,8 +270,8 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
for (size_t request_id = 0; request_id < batch_size; request_id++) {
SequenceGroup::Ptr sequence_group;
if (is_chat_conversation) {
sequence_group = std::make_shared<SequenceGroup>(request_id, m_tokenized_chat_history.input_ids, config, block_size, enable_prefix_caching);
sequence_group->update_processed_tokens_num(m_tokenized_chat_history.input_ids.get_shape().at(1) - 1);
sequence_group = std::make_shared<SequenceGroup>(request_id, m_tokenized_chat_history, config, block_size, enable_prefix_caching);
sequence_group->update_processed_tokens_num(m_tokenized_chat_history.get_shape().at(1) - 1);
} else {
size_t seq_len = input_ids.get_shape().at(1);
size_t batch_offset = request_id * seq_len;
@@ -282,7 +291,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

Sampler sampler = Sampler(m_tokenizer);
// we can't properly refer to the history in the chat scenario with beam search, so reset_kv_state and use the whole history for each new prompt
auto input_tokens = (is_chat_conversation && config.is_beam_search()) ? m_tokenized_chat_history.input_ids : input_ids;
auto input_tokens = (is_chat_conversation && config.is_beam_search()) ? m_tokenized_chat_history : input_ids;
result = ov::genai::get_lm_encoded_results(m_model_runner, input_tokens, concatenated_attention_mask, streamer_ptr, sampler, requests, position_ids, std::nullopt);

if (!is_chat_conversation || config.is_beam_search()) {
@@ -291,6 +300,15 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

if (is_chat_conversation) {
m_is_cache_empty = false;

// drop a trailing eos token, if present, before appending the result to the chat history
auto last_token = result.tokens[0].back() == config.eos_token_id ? result.tokens[0].size() - 1 : result.tokens[0].size();
ov::Tensor new_tokenized_chat_history = ov::Tensor{ov::element::i64, {batch_size, m_tokenized_chat_history.get_shape().at(1) + last_token}};
auto start_chat_hst = m_tokenized_chat_history.data<int64_t>();
std::copy(start_chat_hst, start_chat_hst + m_tokenized_chat_history.get_size(), new_tokenized_chat_history.data<int64_t>());
std::copy(result.tokens[0].begin(), result.tokens[0].begin() + last_token, new_tokenized_chat_history.data<int64_t>() + m_tokenized_chat_history.get_size());

m_tokenized_chat_history = new_tokenized_chat_history;
}

auto stop_time = std::chrono::steady_clock::now();
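
A small sketch of the EOS-trimming step above, written against a plain std::vector<int64_t>; the function name tokens_to_keep is illustrative, and the empty-vector guard is an extra precaution added for the standalone example.

#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch only: how many generated tokens to keep when a trailing EOS should be dropped.
size_t tokens_to_keep(const std::vector<int64_t>& generated, int64_t eos_token_id) {
    if (!generated.empty() && generated.back() == eos_token_id)
        return generated.size() - 1;  // skip the trailing EOS
    return generated.size();
}
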
@@ -311,6 +329,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
m_is_cache_empty = true;
m_history = {};
m_templated_chat_history = "";
m_tokenized_chat_history = ov::Tensor(ov::element::i64, {0, 0});
}
if (system_message.empty())
return;
@@ -328,6 +347,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
m_is_cache_empty = true;
m_history.clear();
m_templated_chat_history.clear();
m_tokenized_chat_history = ov::Tensor(ov::element::i64, {0, 0});
}
}
};
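
As a side note, a sketch of the reset used in start_chat() and finish_chat(): an i64 tensor of shape {0, 0} acts as the "empty history" sentinel, and emptiness can be checked through get_size(). The helper names below are illustrative only.

#include "openvino/runtime/tensor.hpp"

// Sketch only: the empty-history sentinel and a corresponding check.
ov::Tensor make_empty_history() {
    return ov::Tensor(ov::element::i64, {0, 0});  // same reset value as in the commit
}

bool history_is_empty(const ov::Tensor& history) {
    return history.get_size() == 0;  // a {0, 0} tensor has zero elements
}
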
24 changes: 4 additions & 20 deletions src/cpp/src/lm_encoding.cpp
@@ -19,15 +19,6 @@
namespace ov {
namespace genai {

void reset_all_inputs_to_empty_tensors(ov::InferRequest& request, bool is_vlm = false) {
if (!is_vlm)
request.set_tensor("input_ids", ov::Tensor(ov::element::i64, {0, 0}));

request.set_tensor("beam_idx", ov::Tensor(ov::element::i32, {0}));
if (request.get_compiled_model().inputs().size() == 4)
request.set_tensor("position_ids", ov::Tensor(ov::element::i64, {0, 0}));
}

void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention_mask) {
const size_t batch_size = attention_mask.get_shape().at(0);
const size_t sequence_length = attention_mask.get_shape().at(1);
@@ -173,15 +164,11 @@ EncodedResults get_lm_encoded_results(

if (m_embedding.has_value()) {
const ov::Tensor& embed_prompt_tensor = (*m_embedding).infer(new_input_ids);

m_llm.get_tensor("inputs_embeds").set_shape(embed_prompt_tensor.get_shape());
m_llm.set_tensor("inputs_embeds", embed_prompt_tensor);
} else {
m_llm.get_tensor("input_ids").set_shape({total_num_tokens, 1});
m_llm.set_tensor("input_ids", new_input_ids);
}

m_llm.get_tensor("beam_idx").set_shape({ total_num_tokens });
m_llm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {total_num_tokens}, next_beams.data()});

update_attention_mask_with_beams(m_llm.get_tensor("attention_mask"), next_beams);
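
For reference, a sketch of the pattern the simplified beam_idx update above relies on: constructing an ov::Tensor over an existing host buffer is zero-copy, and InferRequest::set_tensor replaces the input tensor outright, which is presumably why the separate set_shape() calls could be dropped. The request and buffer names are illustrative, and the buffer must stay alive until inference completes.

#include <cstdint>
#include <vector>
#include "openvino/runtime/infer_request.hpp"

// Sketch only: wrap an existing buffer as the "beam_idx" input without copying.
void set_beam_idx(ov::InferRequest& request, std::vector<int32_t>& beam_indices) {
    request.set_tensor("beam_idx",
                       ov::Tensor{ov::element::i32, {beam_indices.size()}, beam_indices.data()});
}
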
@@ -221,18 +208,15 @@ EncodedResults get_lm_encoded_results(
streamer_ptr->end();
}

reset_all_inputs_to_empty_tensors(m_llm, m_embedding.has_value());

for (size_t i = 0; i < sequence_groups.size(); i++) {
auto request = sequence_groups[i];
for (auto& sequence_group : sequence_groups) {
// sequences are sorted by cumulative_log_prob with length_penalty
auto outputs = request->get_finished_sequences();
auto outputs = sequence_group->get_finished_sequences();

auto num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, outputs.size());
auto num_outputs = std::min(sequence_group->get_sampling_parameters().num_return_sequences, outputs.size());
for (size_t output_idx = 0; output_idx < num_outputs; ++output_idx) {
const auto& output = outputs[output_idx];
results.tokens.push_back(std::move(output->get_generated_ids()));
results.scores.push_back(output->get_cumulative_score_with_length_penalty(request->get_sampling_parameters()));
results.scores.push_back(output->get_cumulative_score_with_length_penalty(sequence_group->get_sampling_parameters()));
}
}
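
A condensed sketch of the collection loop above; SequenceGroup and Sequence are internal to the library, so the types below are stand-ins, and the input is assumed to be pre-sorted best-first as the comment in the loop states.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch only: stand-in for an internal finished sequence.
struct FinishedSequence {
    std::vector<int64_t> token_ids;
    float score;  // cumulative log-prob with length penalty applied
};

// Append up to num_return_sequences outputs of one request to the flat result lists.
void collect_outputs(const std::vector<FinishedSequence>& finished,  // assumed sorted best-first
                     size_t num_return_sequences,
                     std::vector<std::vector<int64_t>>& tokens,
                     std::vector<float>& scores) {
    const size_t num_outputs = std::min(num_return_sequences, finished.size());
    for (size_t i = 0; i < num_outputs; ++i) {
        tokens.push_back(finished[i].token_ids);
        scores.push_back(finished[i].score);
    }
}
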

