diff --git a/src/cpp/src/group_beam_searcher.cpp b/src/cpp/src/group_beam_searcher.cpp deleted file mode 100644 index a0262c0dc8..0000000000 --- a/src/cpp/src/group_beam_searcher.cpp +++ /dev/null @@ -1,455 +0,0 @@ -// Copyright (C) 2023-2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 - -#include - -#include - -#include "openvino/genai/llm_pipeline.hpp" -#include "utils.hpp" -#include "lm_encoding.hpp" - -namespace { - -// Modified Knuth–Morris–Pratt algorithm which returns tokens following after every needle occurrence in haystack -std::vector kmp_search(const std::vector& haystack, const std::vector& needle) { - if (needle.empty()) { // no_repeat_ngram_size == 1, ban every token - return {haystack.begin(), haystack.end()}; - } - std::vector partial_match_table(needle.size() + 1, -1); - int cnd = 0; - for (size_t pos = 1; pos < needle.size(); ++pos) { - if (needle.at(pos) == needle.at(size_t(cnd))) { - partial_match_table.at(pos) = partial_match_table.at(size_t(cnd)); - } else { - partial_match_table.at(pos) = cnd; - while (cnd >= 0 && needle.at(pos) != needle.at(size_t(cnd))) { - cnd = partial_match_table.at(size_t(cnd)); - } - } - ++cnd; - } - partial_match_table.back() = cnd; - std::vector res; - size_t haystack_id = 0; - int needle_id = 0; - while (haystack_id < haystack.size() - 1) { - if (needle.at(size_t(needle_id)) == haystack.at(haystack_id)) { - ++haystack_id; - ++needle_id; - if (needle_id == int(needle.size())) { - res.push_back(haystack.at(haystack_id)); - needle_id = partial_match_table.at(size_t(needle_id)); - } - } else { - needle_id = partial_match_table.at(size_t(needle_id)); - if (needle_id < 0) { - ++haystack_id; - ++needle_id; - } - } - } - return res; -} - -struct Token { - float log_prob; - int64_t idx; -}; - -std::vector log_softmax(const ov::Tensor& logits, const size_t batch_idx) { - if (logits.get_shape().at(0) <= batch_idx) { - throw std::runtime_error("logits batch size doesn't match the number of beams"); - } - size_t vocab_size = logits.get_shape().back(); - size_t batch_offset = batch_idx * logits.get_shape().at(1) * vocab_size; - size_t sequence_offset = (logits.get_shape().at(1) - 1) * vocab_size; - const float* beam_logits = logits.data() + batch_offset + sequence_offset; - float max_logit = *std::max_element(beam_logits, beam_logits + vocab_size); - float log_sum = std::log( - std::accumulate(beam_logits, beam_logits + vocab_size, 0.0f, [max_logit](float accumulated, float to_add) { - return accumulated + std::exp(to_add - max_logit); - })); - std::vector tokens; - tokens.reserve(vocab_size); - for (size_t idx = 0; idx < vocab_size; ++idx) { - tokens.push_back({beam_logits[idx] - max_logit - log_sum, int64_t(idx)}); - } - return tokens; -} - -struct Beam { - float score = -std::numeric_limits::infinity(); // The bigger, the better - std::vector tokens; - size_t global_beam_idx = 0; -}; - -bool greater(const Beam& left, const Beam& right) { - return left.score > right.score; -} - -struct Parameters { - std::vector> prompts; - int64_t eos_token_id; - size_t n_groups = 3; - size_t group_size = 5; - float diversity_penalty = 1.0; - size_t max_new_tokens = 20; - ov::genai::StopCriteria stop_criteria = ov::genai::StopCriteria::HEURISTIC; - float length_penalty = 1.0; - size_t no_repeat_ngram_size = std::numeric_limits::max(); - - std::function early_finish = [](const Beam&) { - return false; - }; -}; - -struct Group { - std::vector ongoing; // Best beams in front - std::vector min_heap; // The worst of the best completed beams is the first - bool done = false; - - void finish(Beam&& beam, const Parameters& parameters) { - beam.score /= std::pow(float(beam.tokens.size()), parameters.length_penalty); - - min_heap.push_back(std::move(beam)); - std::push_heap(min_heap.begin(), min_heap.end(), greater); - if (min_heap.size() > parameters.group_size) { - std::pop_heap(min_heap.begin(), min_heap.end(), greater); - min_heap.pop_back(); - } - } - void is_done(const Parameters& parameters) { - if (min_heap.size() < parameters.group_size) { - return; - } - size_t cur_len = ongoing.front().tokens.size(); - float best_sum_logprobs = ongoing.front().score; - float worst_score = min_heap.front().score; - switch (parameters.stop_criteria) { - case ov::genai::StopCriteria::EARLY: - done = true; - return; - case ov::genai::StopCriteria::HEURISTIC: { - float highest_attainable_score = best_sum_logprobs / std::pow(float(cur_len), parameters.length_penalty); - done = worst_score >= highest_attainable_score; - return; - } - case ov::genai::StopCriteria::NEVER: { - size_t length = parameters.length_penalty > 0.0 ? parameters.max_new_tokens : cur_len; - float highest_attainable_score = best_sum_logprobs / std::pow(float(length), parameters.length_penalty); - done = worst_score >= highest_attainable_score; - return; - } - default: - throw std::runtime_error("Never reached"); - } - } -}; - -// GroupBeamSearcher processes logits prduced by a language model and accumulates beams using group beam search -// algorithm. select_next_tokens() returns token ids selected by the algorithm and corresponding beam ids. These values -// are used for next inference. select_next_tokens() returns empty, if all groups are completed -struct GroupBeamSearcher { - Parameters parameters; - std::vector> prompts_groups; - - GroupBeamSearcher(Parameters parameters) : parameters{parameters}, prompts_groups{parameters.prompts.size()} { - if (parameters.no_repeat_ngram_size == 0) { - throw std::runtime_error("no_repeat_ngram_size must be positive"); - } - for (std::vector& prompts_groups : prompts_groups) { - prompts_groups.resize(parameters.n_groups); - for (Group& group : prompts_groups) { - group.ongoing.resize(parameters.group_size); - group.ongoing.front().score = 0.0; - } - } - } - - std::pair, std::vector> select_next_tokens(const ov::Tensor& logits) { - std::vector next_tokens; - std::vector next_beams; - - const size_t promts_size = parameters.prompts.size(); - - next_tokens.reserve(promts_size * parameters.n_groups * parameters.group_size); - next_beams.reserve(promts_size * parameters.n_groups * parameters.group_size); - - size_t beam_count = 0; - size_t prompt_id = 0; - for (std::vector& groups : prompts_groups) { - for (Group& group : groups) { - if (group.done) { - continue; - } - for (Beam& beam : group.ongoing) { - // beam.tokens.empty() holds for the first select_next_tokens() call. - // Every beam is constructed from the single batch at first call - if (beam.tokens.empty()) { - beam.global_beam_idx = prompt_id; - } else { - beam.global_beam_idx = beam_count; - ++beam_count; - } - } - } - - prompt_id += 1; - } - - for (int prompt_id = 0; prompt_id < promts_size; prompt_id++) { - const std::vector prompt = parameters.prompts[prompt_id]; - std::vector& groups = prompts_groups[prompt_id]; - auto [prompt_next_tokens, prompt_next_beams] = select_prompt_next_tokens(logits, prompt, groups); - - next_tokens.insert(next_tokens.end(), prompt_next_tokens.begin(), prompt_next_tokens.end()); - next_beams.insert(next_beams.end(), prompt_next_beams.begin(), prompt_next_beams.end()); - } - - return {next_tokens, next_beams}; - } - - std::pair, std::vector> select_prompt_next_tokens(const ov::Tensor& logits, - const std::vector& prompt, - std::vector& groups) { - std::vector next_tokens; - std::vector next_beams; - next_tokens.reserve(parameters.n_groups * parameters.group_size); - next_beams.reserve(parameters.n_groups * parameters.group_size); - - for (auto group = groups.begin(); group != groups.end(); ++group) { - if (group->done) { - continue; - } - std::vector candidates; - candidates.reserve(parameters.group_size * 2 * parameters.group_size); - for (const Beam& beam : group->ongoing) { - std::vector tokens = log_softmax(logits, beam.global_beam_idx); - for (auto prev_group = groups.cbegin(); prev_group != group; ++prev_group) { - for (const Beam& prev_beam : prev_group->ongoing) { - if (prev_beam.tokens.size() > beam.tokens.size()) { - tokens.at(size_t(prev_beam.tokens.back())).log_prob -= parameters.diversity_penalty; - } - } - } - std::vector full_text{prompt}; - full_text.insert(full_text.end(), beam.tokens.begin(), beam.tokens.end()); - if (full_text.size() > 1 && full_text.size() >= parameters.no_repeat_ngram_size) { - auto tail_start = full_text.end() - ptrdiff_t(parameters.no_repeat_ngram_size) + 1; - for (int64_t banned_token : kmp_search(full_text, {tail_start, full_text.end()})) { - tokens.at(size_t(banned_token)).log_prob = -std::numeric_limits::infinity(); - } - } - std::sort(tokens.begin(), tokens.end(), [](Token left, Token right) { - return left.log_prob > right.log_prob; // Most probable tokens in front - }); - size_t add_count = 0; - for (Token token : tokens) { - Beam new_candidate = beam; - new_candidate.score += token.log_prob; - new_candidate.tokens.push_back(token.idx); - if (parameters.early_finish(new_candidate)) { - group->finish(std::move(new_candidate), parameters); - } else { - candidates.push_back(std::move(new_candidate)); - ++add_count; - if (add_count == 2 * parameters.group_size) { - break; - } - } - } - } - // Sample 2 * group_size highest score tokens to get at least 1 non EOS token per beam - if (candidates.size() < 2 * parameters.group_size) { - throw std::runtime_error("No beams left to search"); - } - auto to_sort = candidates.begin() + ptrdiff_t(2 * parameters.group_size); - std::partial_sort(candidates.begin(), to_sort, candidates.end(), greater); - group->ongoing.clear(); - for (size_t cand_idx = 0; cand_idx < candidates.size(); ++cand_idx) { - if (parameters.eos_token_id == candidates.at(cand_idx).tokens.back()) { - // If beam_token does not belong to top num_beams tokens, it should not be added - if (cand_idx >= parameters.group_size) { - continue; - } - group->finish(std::move(candidates.at(cand_idx)), parameters); - } else { - group->ongoing.push_back(std::move(candidates.at(cand_idx))); - if (group->ongoing.size() == parameters.group_size) { - break; - } - } - } - group->is_done(parameters); - if (!group->done) { - for (const Beam& beam : group->ongoing) { - next_tokens.push_back(beam.tokens.back()); - next_beams.push_back(int32_t(beam.global_beam_idx)); - } - } - } - return {next_tokens, next_beams}; - } -}; - -// Consume group_beam_searcher because beams are consumed -std::vector>> finalize(GroupBeamSearcher&& group_beam_searcher) { - std::vector>> finalized; - finalized.resize(group_beam_searcher.prompts_groups.size()); - - for (size_t prompt_id = 0; prompt_id < group_beam_searcher.prompts_groups.size(); prompt_id++) { - std::vector& groups = group_beam_searcher.prompts_groups.at(prompt_id); - finalized.at(prompt_id).reserve(groups.size()); - - for (Group& group : groups) { - if (!group.done) { - for (Beam& beam : group.ongoing) { - group.finish(std::move(beam), group_beam_searcher.parameters); - } - } - finalized.at(prompt_id).push_back(std::move(group.min_heap)); - } - } - - return finalized; -} - -void reset_all_inputs_to_empty_tensors(ov::InferRequest& request) { - request.set_tensor("input_ids", ov::Tensor(ov::element::i64, {0, 0})); - request.set_tensor("beam_idx", ov::Tensor(ov::element::i32, {0})); - if (request.get_compiled_model().inputs().size() == 4) - request.set_tensor("position_ids", ov::Tensor(ov::element::i64, {0, 0})); -} -} // namespace - -namespace ov { -namespace genai { - -std::pair beam_search(ov::InferRequest& lm, - ov::Tensor input_ids, - ov::Tensor attention_mask, - GenerationConfig config, - std::optional position_ids, - std::optional selected_beam_idx) { - OPENVINO_ASSERT(config.num_beams % config.num_beam_groups == 0, - "number of beams should be divisible by number of groups"); - - auto batch_size = input_ids.get_shape().at(0); - auto sequence_length = input_ids.get_shape().at(1); - - // Initialize beam search. - const int64_t* prompt_data = input_ids.data(); - std::vector> prompts; - prompts.reserve(batch_size); - for (size_t batch = 0; batch < batch_size; batch++) { - size_t batch_offset = batch * sequence_length; - const int64_t* prompt_start = prompt_data + batch_offset; - prompts.push_back(std::vector{prompt_start, prompt_start + sequence_length}); - } - - lm.set_tensor("input_ids", input_ids); - lm.set_tensor("attention_mask", attention_mask); - if (position_ids.has_value()) - lm.set_tensor("position_ids", *position_ids); - - ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {batch_size}); - auto beam_data = beam_idx.data(); - if (selected_beam_idx.has_value()) - beam_data[0] = *selected_beam_idx; - else - std::fill_n(beam_data, batch_size, 0); - lm.set_tensor("beam_idx", beam_idx); - - Parameters parameters{std::move(prompts)}; - parameters.max_new_tokens = config.get_max_new_tokens(sequence_length); - parameters.eos_token_id = config.eos_token_id; - parameters.n_groups = config.num_beam_groups; - parameters.group_size = config.num_beams / config.num_beam_groups; - parameters.diversity_penalty = config.diversity_penalty; - parameters.length_penalty = config.length_penalty; - parameters.stop_criteria = config.stop_criteria; - parameters.no_repeat_ngram_size = config.no_repeat_ngram_size; - GroupBeamSearcher group_beam_searcher{parameters}; - - std::vector next_tokens; - std::vector next_beams; - - // Reserve for performance counters. - std::vector new_token_times; - std::vector batch_sizes; - new_token_times.reserve(parameters.max_new_tokens); - batch_sizes.reserve(parameters.max_new_tokens); - - for (size_t length_count = 0; ; ++length_count) { - lm.infer(); - - std::tie(next_tokens, next_beams) = group_beam_searcher.select_next_tokens(lm.get_tensor("logits")); - new_token_times.emplace_back(std::chrono::steady_clock::now()); - batch_sizes.emplace_back(batch_size); - - if (next_tokens.empty() || length_count == parameters.max_new_tokens - 1) { - // Break the cycle before masks are extended in update_attention_mask_with_beams. - // If generation is continued, attention_mask length should be equal to KV cache size. - break; - } - - size_t running_batch_size = next_tokens.size(); - // Set pointers - lm.set_tensor("input_ids", ov::Tensor{ov::element::i64, {running_batch_size, 1}, next_tokens.data()}); - lm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {running_batch_size}, next_beams.data()}); - - // Set auxiliary inputs - update_attention_mask_with_beams(lm.get_tensor("attention_mask"), next_beams); - if (position_ids.has_value()) - update_position_ids(lm.get_tensor("position_ids"), lm.get_tensor("attention_mask")); - } - - reset_all_inputs_to_empty_tensors(lm); - - auto scores_comparator = [](Beam& left, Beam& right) { - return (left.score > right.score); - }; - - auto result = finalize(std::move(group_beam_searcher)); - ov::genai::EncodedResults results; - int32_t res_selected_beam_idx = 0; - results.scores.reserve(config.num_return_sequences * result.size()); - results.tokens.reserve(config.num_return_sequences * result.size()); - auto& raw_perf_counters = results.perf_metrics.raw_metrics; - raw_perf_counters.m_new_token_times = new_token_times; - raw_perf_counters.m_batch_sizes = batch_sizes; - - // align output with HF - for (size_t prompt_id = 0; prompt_id < result.size(); prompt_id++) { - auto prompt_group = result.at(prompt_id); - std::vector> plain_beams; - plain_beams.reserve(parameters.n_groups * parameters.group_size); - for (std::vector& group : prompt_group) { - for (Beam& beam : group) { - plain_beams.push_back(beam); - } - } - assert(config.num_return_sequences <= plain_beams.size()); - std::partial_sort( - plain_beams.begin(), - plain_beams.begin() + config.num_return_sequences, - plain_beams.end(), - scores_comparator - ); - res_selected_beam_idx = plain_beams.at(0).get().global_beam_idx; - for ( - auto beam = plain_beams.begin(); - beam != plain_beams.begin() + config.num_return_sequences; - ++beam - ) { - results.scores.push_back(beam->get().score); - results.tokens.push_back(std::move(beam->get().tokens)); - } - } - - return {results, res_selected_beam_idx}; -} - -} // namespace genai -} // namespace ov diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp index 62a72b1cbd..ca01e8e11c 100644 --- a/src/cpp/src/llm_pipeline.cpp +++ b/src/cpp/src/llm_pipeline.cpp @@ -21,18 +21,11 @@ #include "sampler.hpp" #include "lm_encoding.hpp" +#include "debug_utils.hpp" + namespace ov { namespace genai { -std::pair beam_search( - ov::InferRequest& lm, - ov::Tensor prompts, - ov::Tensor attention_mask, - GenerationConfig config, - std::optional position_ids, - std::optional selected_beam_idx -); - class StatefulLLMPipeline final : public LLMPipelineImplBase { public: ov::InferRequest m_model_runner; @@ -42,7 +35,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { std::optional m_selected_beam = std::nullopt; ChatHistory m_history; std::string m_templated_chat_history = {}; - TokenizedInputs m_tokenized_chat_history; + std::vector m_tokenized_chat_history; StatefulLLMPipeline( const ov::InferRequest& request, @@ -126,7 +119,6 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens); } m_templated_chat_history = new_templated_chat_history; - m_tokenized_chat_history = new_chat_tokens; // TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied } else { encoded_input = m_tokenizer.encode(prompt); @@ -222,24 +214,34 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { "(input_ids, attention_mask, position_ids, beam_idx) " "but you have '" + std::to_string(num_inputs) + "' inputs"); + if (is_chat_conversation) { + std::copy(input_ids.data(), input_ids.data() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history)); + } + ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data()); + bool kv_history_available = m_selected_beam.has_value(); size_t kv_cache_len = 0; ov::Tensor concatenated_attention_mask; if (is_chat_conversation && !m_is_cache_empty) { - OPENVINO_ASSERT(batch_size == 1, "continuation of generation is possible only for batch 1"); - // If history is saved in KV cache, concatenate new attention_mask with the already existing. - // Between subsequent runs attention_mask should not be modified. - auto atten_mask_history = m_model_runner.get_tensor("attention_mask"); - auto prompt_len = attention_mask.get_shape()[1]; - kv_cache_len = atten_mask_history.get_shape()[1]; - - ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}}; - auto start_atten_hst = atten_mask_history.data() + kv_cache_len * (*m_selected_beam); - std::copy(start_atten_hst, start_atten_hst + kv_cache_len, - new_atten_mask.data()); - std::copy(attention_mask.data(), attention_mask.data() + prompt_len, - new_atten_mask.data() + kv_cache_len); - concatenated_attention_mask = new_atten_mask; + if (kv_history_available) { + OPENVINO_ASSERT(batch_size == 1, "continuation of generation is possible only for batch 1"); + // If history is saved in KV cache, concatenate new attention_mask with the already existing. + // Between subsequent runs attention_mask should not be modified. + auto atten_mask_history = m_model_runner.get_tensor("attention_mask"); + auto prompt_len = attention_mask.get_shape()[1]; + kv_cache_len = atten_mask_history.get_shape()[1]; + + ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}}; + auto start_atten_hst = atten_mask_history.data(); + std::copy(start_atten_hst, start_atten_hst + kv_cache_len, + new_atten_mask.data()); + std::copy(attention_mask.data(), attention_mask.data() + prompt_len, + new_atten_mask.data() + kv_cache_len); + concatenated_attention_mask = new_atten_mask; + } else { + attention_mask = ov::genai::utils::init_attention_mask(tokenized_chat_history); + concatenated_attention_mask = attention_mask; + } } else { concatenated_attention_mask = attention_mask; } @@ -247,8 +249,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { bool position_ids_available = (num_inputs == 4); std::optional position_ids = std::nullopt; if (position_ids_available) { - position_ids = ov::Tensor{ov::element::i64, input_ids.get_shape()}; - utils::initialize_position_ids(*position_ids, attention_mask, kv_cache_len); + if (is_chat_conversation && !kv_history_available) { + position_ids = ov::Tensor{ov::element::i64, tokenized_chat_history.get_shape()}; + } else { + position_ids = ov::Tensor{ov::element::i64, input_ids.get_shape()}; + } + utils::initialize_position_ids(*position_ids, attention_mask, kv_cache_len); } if(m_adapter_controller) { @@ -256,43 +262,50 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { } ov::genai::EncodedResults result; - if (config.is_beam_search() && is_chat_conversation) { - std::tie(result, m_selected_beam) = beam_search(m_model_runner, input_ids, concatenated_attention_mask, - config, position_ids, m_selected_beam); - } else { - std::vector requests; - size_t block_size = 1; - bool enable_prefix_caching = false; - - config.stop_token_ids.insert(config.eos_token_id); - for (size_t request_id = 0; request_id < batch_size; request_id++) { - SequenceGroup::Ptr sequence_group; - if (is_chat_conversation && !m_is_cache_empty) { - sequence_group = std::make_shared(request_id, m_tokenized_chat_history.input_ids, config, block_size, enable_prefix_caching); - } else { - size_t seq_len = input_ids.get_shape().at(1); - size_t batch_offset = request_id * seq_len; - const int64_t* prompt_start = input_ids.data() + batch_offset; - std::vector tokenized_prompt(prompt_start, prompt_start + seq_len); - - sequence_group = std::make_shared(request_id, tokenized_prompt, config, block_size, enable_prefix_caching); - } + std::vector requests; + size_t block_size = 1; + bool enable_prefix_caching = false; + for (size_t request_id = 0; request_id < batch_size; request_id++) { + SequenceGroup::Ptr sequence_group; + if (is_chat_conversation) { + sequence_group = std::make_shared(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching); + } else { + size_t seq_len = input_ids.get_shape().at(1); + size_t batch_offset = request_id * seq_len; + const int64_t* prompt_start = input_ids.data() + batch_offset; + std::vector tokenized_prompt(prompt_start, prompt_start + seq_len); - sequence_group->set_sequence_group_ptr(sequence_group); - requests.push_back(sequence_group); + sequence_group = std::make_shared(request_id, tokenized_prompt, config, block_size, enable_prefix_caching); } - Sampler sampler = Sampler(m_tokenizer); - std::tie(result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask, streamer_ptr, - sampler, requests, position_ids, std::nullopt, m_selected_beam); + sequence_group->set_sequence_group_ptr(sequence_group); + requests.push_back(sequence_group); + } + + Sampler sampler = Sampler(m_tokenizer); + // we can't properly refer to history in case of chat scenario with beam search, so reset_kv_state and use the whole history for each new propmt + auto input_tokens = input_ids; + if (is_chat_conversation && !kv_history_available) { + input_tokens = tokenized_chat_history; } + result = ov::genai::get_lm_encoded_results(m_model_runner, input_tokens, concatenated_attention_mask, streamer_ptr, sampler, requests, position_ids, std::nullopt); - if (!is_chat_conversation) { + m_selected_beam = 0; + if (!is_chat_conversation || config.is_beam_search()) { reset_kv_state(); m_selected_beam = std::nullopt; - } else { + } + + if (is_chat_conversation) { m_is_cache_empty = false; } + + if (is_chat_conversation) { + auto decoded_result = m_tokenizer.decode(result.tokens[0]); + auto answer = m_tokenizer.encode(decoded_result, ov::genai::add_special_tokens(false)).input_ids; + std::copy(answer.data(), answer.data() + answer.get_size(), std::back_inserter(m_tokenized_chat_history)); + } + auto stop_time = std::chrono::steady_clock::now(); // If is called without tokenization then that stat will not be reported. @@ -306,12 +319,13 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { void start_chat(const std::string& system_message) override { is_chat_conversation = true; - m_selected_beam = std::nullopt; + m_selected_beam = std::nullopt; if (!m_is_cache_empty) { reset_kv_state(); m_is_cache_empty = true; m_history = {}; m_templated_chat_history = ""; + m_tokenized_chat_history = {}; } if (system_message.empty()) return; @@ -330,6 +344,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { m_is_cache_empty = true; m_history.clear(); m_templated_chat_history.clear(); + m_tokenized_chat_history = {}; } } }; diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp index c76d9f7edf..15b497ca76 100644 --- a/src/cpp/src/lm_encoding.cpp +++ b/src/cpp/src/lm_encoding.cpp @@ -12,10 +12,10 @@ #include "lm_encoding.hpp" #include "openvino/genai/perf_metrics.hpp" -#include "debug_utils.hpp" - #include "utils.hpp" +#include "debug_utils.hpp" + namespace ov { namespace genai { @@ -51,7 +51,7 @@ void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector get_lm_encoded_results( +EncodedResults get_lm_encoded_results( ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask, @@ -59,14 +59,8 @@ std::pair get_lm_encoded_results( Sampler& sampler, std::vector sequence_groups, std::optional position_ids, - std::optional m_embedding, - std::optional selected_beam_idx + std::optional m_embedding ) { - std::vector generations; - for (SequenceGroup::Ptr sequence_group : sequence_groups) { - generations.push_back(std::make_shared(sequence_group->get_generation_stream(), sequence_group->get_sampling_parameters())); - } - ov::Shape prompts_shape = input_ids.get_shape(); const size_t batch_size = prompts_shape[0]; @@ -88,10 +82,7 @@ std::pair get_lm_encoded_results( ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {batch_size}); auto beam_data = beam_idx.data(); - if (selected_beam_idx.has_value()) - beam_data[0] = *selected_beam_idx; - else - std::fill_n(beam_data, batch_size, 0); + std::fill_n(beam_data, batch_size, 0); m_llm.set_tensor("beam_idx", beam_idx); const auto infer_start = std::chrono::steady_clock::now(); @@ -161,7 +152,7 @@ std::pair get_lm_encoded_results( // apply strides to shift to a next sequence input_ids_data += num_scheduled_tokens; - // for different sequences iteration of beams started from 0, but we collect it to one input_ids# + // for different sequences iteration of beams started from 0, but we collect it to one input_ids next_beams.push_back(beam_idxs[sequence->get_id()] + beam_offets.at(sequence_group->get_request_id())); } } @@ -176,22 +167,26 @@ std::pair get_lm_encoded_results( if (m_embedding.has_value()) { const ov::Tensor& embed_prompt_tensor = (*m_embedding).infer(new_input_ids); - - m_llm.get_tensor("inputs_embeds").set_shape(embed_prompt_tensor.get_shape()); m_llm.set_tensor("inputs_embeds", embed_prompt_tensor); } else { - m_llm.get_tensor("input_ids").set_shape(new_input_ids.get_shape()); m_llm.set_tensor("input_ids", new_input_ids); } + m_llm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {total_num_tokens}, next_beams.data()}); + update_attention_mask_with_beams(m_llm.get_tensor("attention_mask"), next_beams); if (position_ids.has_value()) { update_position_ids(m_llm.get_tensor("position_ids"), m_llm.get_tensor("attention_mask")); } - m_llm.get_tensor("beam_idx").set_shape({ total_num_tokens }); - m_llm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {total_num_tokens}, next_beams.data()}); + if (streamer_ptr) { + // stream data from first sequence + int64_t out_token = sequence_groups.at(0).get()->operator[](0)->get_generated_ids().back(); + if (streamer_ptr->put(out_token)) { + break; + } + } const auto infer_start = std::chrono::steady_clock::now(); m_llm.infer(); @@ -202,14 +197,6 @@ std::pair get_lm_encoded_results( raw_perf_counters.m_new_token_times.emplace_back(infer_end); raw_perf_counters.m_batch_sizes.emplace_back(batch_size); - if (streamer_ptr) { - // stream data from first sequence - int64_t out_token = sequence_groups.at(0).get()->operator[](0)->get_generated_ids().back(); - if (streamer_ptr->put(out_token)) { - break; - } - } - sampler_output = sampler.sample(active_sequence_groups, m_llm.get_tensor("logits")); active_sequence_groups.erase(std::remove_if(active_sequence_groups.begin(), @@ -223,26 +210,20 @@ std::pair get_lm_encoded_results( streamer_ptr->put(out_token); streamer_ptr->end(); } - - size_t next_selected_beam = 0; - for (size_t i = 0; i < sequence_groups.size(); i++) { - auto request = sequence_groups[i]; - auto generation_outputs = generations[i]->read_all(); - - std::sort(generation_outputs.begin(), generation_outputs.end(), [] (const GenerationOutput& r1, const GenerationOutput& r2) { - return r1.score > r2.score; - }); - - auto num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, generation_outputs.size()); - for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) { - const auto& generation_output = generation_outputs[generation_output_idx]; - results.tokens.push_back(std::move(generation_output.generated_ids)); - results.scores.push_back(generation_output.score); + + for (auto& sequence_group : sequence_groups) { + // sequences is sorted by cumulative_log_prob with length_penalty + auto outputs = sequence_group->get_finished_sequences(); + + auto num_outputs = std::min(sequence_group->get_sampling_parameters().num_return_sequences, outputs.size()); + for (size_t output_idx = 0; output_idx < num_outputs; ++output_idx) { + const auto& output = outputs[output_idx]; + results.tokens.push_back(output->get_generated_ids()); + results.scores.push_back(output->get_cumulative_score_with_length_penalty(sequence_group->get_sampling_parameters())); } - // next_selected_beam = sampler.last_selected_beam(request); } - return {results, next_selected_beam}; + return results; } } // namespace genai diff --git a/src/cpp/src/lm_encoding.hpp b/src/cpp/src/lm_encoding.hpp index fa6692ede0..f95eae1436 100644 --- a/src/cpp/src/lm_encoding.hpp +++ b/src/cpp/src/lm_encoding.hpp @@ -8,9 +8,9 @@ namespace ov { namespace genai { -std::pair get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask, +EncodedResults get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask, const std::shared_ptr& streamer_ptr, Sampler& sampler, std::vector sequence_groups, - std::optional position_ids, std::optional m_embedding, std::optional selected_beam_idx); + std::optional position_ids, std::optional m_embedding); void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector next_beams); diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp index c5be82f0f2..7315beee95 100644 --- a/src/cpp/src/sequence_group.hpp +++ b/src/cpp/src/sequence_group.hpp @@ -167,7 +167,7 @@ class Sequence { m_generated_log_probs[idx] = log_prob; } - float get_beam_search_score(const ov::genai::GenerationConfig& sampling_params) const { + float get_cumulative_score_with_length_penalty(const ov::genai::GenerationConfig& sampling_params) const { float cumulative_log_prob = get_cumulative_log_probs(), current_length = get_generated_len(); float score = cumulative_log_prob / std::pow(current_length, sampling_params.length_penalty); return score; @@ -339,7 +339,7 @@ class SequenceGroup { // do we need to sort sequences here or sampler can handle it for us? std::sort(finished_seqs.begin(), finished_seqs.end(), [=] (Sequence::CPtr s1, Sequence::CPtr s2) { - return s1->get_beam_search_score(m_sampling_params) > s2->get_beam_search_score(m_sampling_params); + return s1->get_cumulative_score_with_length_penalty(m_sampling_params) > s2->get_cumulative_score_with_length_penalty(m_sampling_params); }); return finished_seqs; @@ -586,7 +586,7 @@ class SequenceGroup { output.generated_ids.insert(output.generated_ids.begin(), m_prompt_ids.begin(), m_prompt_ids.end()); output.generated_log_probs.insert(output.generated_log_probs.begin(), m_prompt_log_probs.begin(), m_prompt_log_probs.end()); } - output.score = m_sampling_params.is_beam_search() ? sequence->get_beam_search_score(m_sampling_params) : sequence->get_cumulative_log_probs(); + output.score = m_sampling_params.is_beam_search() ? sequence->get_cumulative_score_with_length_penalty(m_sampling_params) : sequence->get_cumulative_log_probs(); output.finish_reason = sequence->get_finish_reason(); outputs.emplace(sequence->get_grouped_id(), output); } diff --git a/src/cpp/src/visual_language/pipeline.cpp b/src/cpp/src/visual_language/pipeline.cpp index 9ece0ff754..625a9f24e3 100644 --- a/src/cpp/src/visual_language/pipeline.cpp +++ b/src/cpp/src/visual_language/pipeline.cpp @@ -93,8 +93,6 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { ov::Tensor inputs_embeds = m_inputs_embedder->get_inputs_embeds(prompt, rgbs); - Sampler sampler = Sampler(m_tokenizer); - std::vector requests; size_t request_id = 0; size_t block_size = 1; // not used @@ -131,10 +129,10 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { ov::Tensor position_ids = ov::Tensor{ov::element::i64, { 1, inputs_embeds.get_shape()[1] }}; std::iota(position_ids.data(), position_ids.data() + position_ids.get_size(), history_size); - ov::genai::EncodedResults encoded_result; - int32_t m_selected_beam = 0; - std::tie(encoded_result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, sampler, requests, - position_ids, m_embedding, std::nullopt); + Sampler sampler = Sampler(m_tokenizer); + + ov::genai::EncodedResults encoded_result = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, sampler, requests, + position_ids, m_embedding); DecodedResults decoded; for (size_t idx = 0; idx < encoded_result.tokens.size(); ++idx) {