
Commit 1d7d31d

Dirty version
iefode committed Oct 1, 2024
1 parent ae67be5 commit 1d7d31d
Showing 8 changed files with 519 additions and 121 deletions.
272 changes: 186 additions & 86 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -14,54 +14,36 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(
const Tokenizer& tokenizer,
const SchedulerConfig& scheduler_config,
const std::string& device,
const ov::AnyMap& plugin_config,
bool is_validation_mode_enabled) {
const ov::AnyMap& plugin_config) {
m_tokenizer = tokenizer;
m_is_validation_mode_enabled = is_validation_mode_enabled;
ov::Core core;

ov::Core core;
// The model can be compiled for GPU as well
std::shared_ptr<ov::Model> model = core.read_model(models_path + "/openvino_model.xml");

DeviceConfig device_config(core, scheduler_config, device, plugin_config);

bool is_need_per_layer_cache_control = scheduler_config.use_cache_eviction;
apply_paged_attention_transformations(model, device_config, is_need_per_layer_cache_control);
init(model, scheduler_config, plugin_config, device_config, core);
}

ov::InferRequest infer_request = core.compile_model(model, device_config.get_device(), plugin_config).create_infer_request();

// setup KV caches
m_cache_manager = std::make_shared<CacheManager>(device_config, core);
for (size_t decoder_layer_id = 0; decoder_layer_id < device_config.get_num_layers(); ++decoder_layer_id) {
infer_request.set_tensor(std::string("key_cache.") + std::to_string(decoder_layer_id), m_cache_manager->get_key_cache(decoder_layer_id));
infer_request.set_tensor(std::string("value_cache.") + std::to_string(decoder_layer_id), m_cache_manager->get_value_cache(decoder_layer_id));
}

SchedulerConfig updated_config = scheduler_config;
// update KV number in scheduler config
if (scheduler_config.num_kv_blocks != device_config.get_num_kv_blocks()) {
updated_config.num_kv_blocks = device_config.get_num_kv_blocks();
}

bool can_use_partial_preemption = true;
if (device_config.get_device().find("GPU") != std::string::npos && !updated_config.dynamic_split_fuse) {
// in case of executing a `vLLM-like` pipeline, it's better not to use partial eviction on the GPU,
// as it may lead to performance slowdown
can_use_partial_preemption = false;
}

m_scheduler = std::make_shared<Scheduler>(updated_config, device_config.get_num_layers(), can_use_partial_preemption);
// and finally create model runner
bool is_use_cache_eviction = m_scheduler->get_config().use_cache_eviction;
if (is_use_cache_eviction) {
m_model_runner = std::make_shared<ModelRunner>(infer_request, updated_config, device_config.get_num_layers(), true);
} else {
m_model_runner = std::make_shared<ModelRunner>(infer_request, updated_config, device_config.get_num_layers());
}
m_sampler = std::make_shared<Sampler>(m_tokenizer);
m_sampler->set_seed(m_generation_config.rng_seed);
ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(
ov::Core& core,
const std::shared_ptr<ov::Model>& model,
const Tokenizer& tokenizer,
const DeviceConfig& device_config,
const SchedulerConfig& scheduler_config,
const std::string& device,
const ov::AnyMap& plugin_config,
bool is_validation_mode_enabled) {
m_is_validation_mode_enabled = is_validation_mode_enabled;
init(model, scheduler_config, plugin_config, device_config, core);
}
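// Note: both constructors above delegate their shared setup to an init() helper whose body
// is not part of this diff. A minimal sketch of what it presumably contains, assuming it
// simply absorbs the removed constructor code shown above (the helper's exact signature is
// inferred from the call sites, not confirmed by this commit):
//
//   void ContinuousBatchingPipeline::ContinuousBatchingImpl::init(
//           std::shared_ptr<ov::Model> model,
//           const SchedulerConfig& scheduler_config,
//           const ov::AnyMap& plugin_config,
//           const DeviceConfig& device_config,
//           ov::Core& core) {
//       ov::InferRequest infer_request = core.compile_model(model, device_config.get_device(), plugin_config).create_infer_request();
//
//       // set up KV caches and bind them to the infer request
//       m_cache_manager = std::make_shared<CacheManager>(device_config, core);
//       for (size_t decoder_layer_id = 0; decoder_layer_id < device_config.get_num_layers(); ++decoder_layer_id) {
//           infer_request.set_tensor(std::string("key_cache.") + std::to_string(decoder_layer_id), m_cache_manager->get_key_cache(decoder_layer_id));
//           infer_request.set_tensor(std::string("value_cache.") + std::to_string(decoder_layer_id), m_cache_manager->get_value_cache(decoder_layer_id));
//       }
//
//       // propagate the actual KV block count into the scheduler config
//       SchedulerConfig updated_config = scheduler_config;
//       if (scheduler_config.num_kv_blocks != device_config.get_num_kv_blocks())
//           updated_config.num_kv_blocks = device_config.get_num_kv_blocks();
//
//       // partial preemption is avoided for GPU vLLM-like pipelines, as in the removed code
//       bool can_use_partial_preemption = true;
//       if (device_config.get_device().find("GPU") != std::string::npos && !updated_config.dynamic_split_fuse)
//           can_use_partial_preemption = false;
//
//       m_scheduler = std::make_shared<Scheduler>(updated_config, device_config.get_num_layers(), can_use_partial_preemption);
//       if (updated_config.use_cache_eviction)
//           m_model_runner = std::make_shared<ModelRunner>(infer_request, updated_config, device_config.get_num_layers(), true);
//       else
//           m_model_runner = std::make_shared<ModelRunner>(infer_request, updated_config, device_config.get_num_layers());
//       m_sampler = std::make_shared<Sampler>(m_tokenizer);
//       m_sampler->set_seed(m_generation_config.rng_seed);
//   }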

// read default generation config
void ContinuousBatchingPipeline::ContinuousBatchingImpl::_pull_awaiting_requests() {
    std::lock_guard<std::mutex> lock{m_awaiting_requests_mutex};
    m_requests.insert(m_requests.end(), m_awaiting_requests.begin(), m_awaiting_requests.end());
    m_awaiting_requests.clear();
}

GenerationHandle
@@ -107,11 +89,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
step_timer.start();

// Pull awaiting requests
{
std::lock_guard<std::mutex> lock{m_awaiting_requests_mutex};
m_requests.insert(m_requests.end(), m_awaiting_requests.begin(), m_awaiting_requests.end());
m_awaiting_requests.clear();
}
_pull_awaiting_requests();

m_pipeline_metrics.requests = m_requests.size();
Scheduler::Output scheduler_output;
@@ -294,49 +272,49 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
return results;
}

std::vector<GenerationResult>
ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<std::string>& prompts,
std::vector<ov::genai::GenerationConfig> sampling_params,
const StreamerVariant& streamer) {
std::vector<ov::Tensor> input_ids;
static ManualTimer timer("tokenize");
if (m_is_chat_conversation) {
OPENVINO_ASSERT(1 == prompts.size(), "Can't chat with multiple prompts");
m_history.push_back({{"role", "user"}, {"content", prompts.at(0)}});
constexpr bool add_generation_prompt = true;
std::string history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
timer.start();
input_ids.push_back(m_tokenizer.encode(history).input_ids);
timer.end();
} else {
input_ids.reserve(prompts.size());
for (const std::string& prompt : prompts) {
timer.start();
input_ids.push_back(m_tokenizer.encode(prompt).input_ids);
timer.end();
}
}
std::vector<EncodedGenerationResult> encoded = generate(input_ids, sampling_params, streamer);
std::vector<GenerationResult> decoded;
decoded.reserve(encoded.size());
for (EncodedGenerationResult& res : encoded) {
std::vector<std::string> generated;
generated.reserve(res.m_generation_ids.size());
for (size_t idx = 0; idx < res.m_generation_ids.size(); ++idx) {
generated.push_back(m_tokenizer.decode(res.m_generation_ids.at(idx)));
if (m_is_chat_conversation && 0 == idx) {
m_history.push_back({{"role", "assistant"}, {"content", generated.back()}});
}
}
decoded.push_back(GenerationResult{
res.m_request_id,
std::move(generated),
std::move(res.m_scores),
res.m_status
});
}
return decoded;
}
// std::vector<GenerationResult>
// ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<std::string>& prompts,
// std::vector<ov::genai::GenerationConfig> sampling_params,
// const StreamerVariant& streamer) {
// std::vector<ov::Tensor> input_ids;
// static ManualTimer timer("tokenize");
// if (m_is_chat_conversation) {
// OPENVINO_ASSERT(1 == prompts.size(), "Can't chat with multiple prompts");
// m_history.push_back({{"role", "user"}, {"content", prompts.at(0)}});
// constexpr bool add_generation_prompt = true;
// std::string history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
// timer.start();
// input_ids.push_back(m_tokenizer.encode(history).input_ids);
// timer.end();
// } else {
// input_ids.reserve(prompts.size());
// for (const std::string& prompt : prompts) {
// timer.start();
// input_ids.push_back(m_tokenizer.encode(prompt).input_ids);
// timer.end();
// }
// }
// std::vector<EncodedGenerationResult> encoded = generate(input_ids, sampling_params, streamer);
// std::vector<GenerationResult> decoded;
// decoded.reserve(encoded.size());
// for (EncodedGenerationResult& res : encoded) {
// std::vector<std::string> generated;
// generated.reserve(res.m_generation_ids.size());
// for (size_t idx = 0; idx < res.m_generation_ids.size(); ++idx) {
// generated.push_back(m_tokenizer.decode(res.m_generation_ids.at(idx)));
// if (m_is_chat_conversation && 0 == idx) {
// m_history.push_back({{"role", "assistant"}, {"content", generated.back()}});
// }
// }
// decoded.push_back(GenerationResult{
// res.m_request_id,
// std::move(generated),
// std::move(res.m_scores),
// res.m_status
// });
// }
// return decoded;
// }

void ContinuousBatchingPipeline::ContinuousBatchingImpl::_free_non_running_requests() {
std::vector<SequenceGroup::Ptr>::iterator requests_iterator = m_requests.begin();
@@ -413,4 +391,126 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::maybe_evict_cache_block
seq_group_ptr->register_token_eviction(num_blocks_evicted * sched_config.block_size);
}
}

void ContinuousBatchingPipeline::ContinuousBatchingImpl::finish_request(int64_t request_id) {
    if (request_id == -1) {
        while (!m_requests.empty()) {
            const auto& request = *m_requests.rbegin();
            for (const auto& sequence : request->get_sequences()) {
                m_scheduler->free_sequence(sequence->get_id());
            }
            m_sampler->clear_beam_search_info(request->get_request_id());
            m_requests.pop_back();
        }
    } else {
        for (size_t i = 0; i < m_requests.size(); ++i) {
            auto& request = m_requests[i];
            if (request->get_request_id() != request_id) {
                continue;
            }
            for (const auto& sequence : request->get_sequences()) {
                m_scheduler->free_sequence(sequence->get_id());
            }
            m_sampler->clear_beam_search_info(request->get_request_id());
            m_requests.erase(m_requests.begin() + i);
            break;
        }
    }
}

std::vector<ContinuousBatchingPipeline::ContinuousBatchingImpl::GeneratedSequence>
ContinuousBatchingPipeline::ContinuousBatchingImpl::get_generated_sequences() {
    _pull_awaiting_requests();
    std::vector<ContinuousBatchingPipeline::ContinuousBatchingImpl::GeneratedSequence> result;
    for (const auto& request : m_requests) {
        const auto request_id = request->get_request_id();
        for (const auto& sequence : request->get_sequences()) {
            auto generated_ids = sequence->get_generated_ids();
            auto log_probs = sequence->get_generated_log_probs();
            result.emplace_back(request_id, sequence->get_grouped_id(), generated_ids, log_probs);
        }
    }
    return result;
}

ContinuousBatchingPipeline::ContinuousBatchingImpl::UpdateSeqResult
ContinuousBatchingPipeline::ContinuousBatchingImpl::update_generated_sequence(
    const ContinuousBatchingPipeline::ContinuousBatchingImpl::GeneratedSequence& candidate_sequence) {
    _pull_awaiting_requests();
    bool is_empty_generated_tokens = false;
    for (auto& request : m_requests) {
        if (candidate_sequence.request_id == request->get_request_id()) {
            bool is_seq_exists = false;
            // todo: iefode: multiseq
            size_t to_remove_tokens = 0, to_insert_tokens = 0;
            for (auto& sequence : request->get_sequences()) {
                if (candidate_sequence.sequence_id == sequence->get_grouped_id()) {
                    is_seq_exists = true;
                    auto present_ids = sequence->get_generated_ids();
                    const auto& candidate_ids = candidate_sequence.token_ids;

                    // remove extra tokens from sequence
                    {
                        auto token_idx = std::min(present_ids.size(), candidate_ids.size());
                        if (token_idx) {
                            while (token_idx-- > 0) {
                                if (present_ids[token_idx] == candidate_ids[token_idx]) {
                                    break;
                                }
                            }
                            to_remove_tokens = present_ids.size() - (token_idx + 1);
                            if (to_remove_tokens > 0) {
                                const auto gen_ids_before = sequence->get_generated_ids();
                                sequence->remove_last_tokens(to_remove_tokens);
                                present_ids = sequence->get_generated_ids();
                                const size_t gen_len_before = gen_ids_before.size(),
                                             gen_len_after = present_ids.size();
                                if (gen_len_after == 0) {
                                    is_empty_generated_tokens = true;
                                }
                                OPENVINO_ASSERT(gen_len_after < gen_len_before);
                                for (size_t i = gen_len_after; i < gen_len_before; ++i) {
                                    // todo
                                    // m_sampler->update_logit_processor(request->get_request_id(), gen_ids_before[i]);
                                }
                            }
                        }
                    }
                    // insert new tokens to sequence
                    {
                        OPENVINO_ASSERT(candidate_ids.size() >= present_ids.size());
                        const auto& candidate_log_probs = candidate_sequence.log_probs;
                        const size_t start_id = std::min(present_ids.size(), candidate_ids.size()),
                                     stop_id = std::max(present_ids.size(), candidate_ids.size());
                        to_insert_tokens = stop_id - start_id;
                        for (size_t i = start_id; i < stop_id; ++i) {
                            sequence->append_token(candidate_ids[i], i < candidate_log_probs.size() ? candidate_log_probs[i] : 0.f);
                        }
                    }
                }
                break;
            }
            if (!is_seq_exists) {
                Sequence::Ptr new_sequence(new Sequence(candidate_sequence.sequence_id));
                const auto& generated_tokens = candidate_sequence.token_ids;
                const auto& generated_log_probs = candidate_sequence.log_probs;
                for (size_t i = 0; i < generated_tokens.size(); ++i) {
                    new_sequence->append_token(generated_tokens[i], generated_log_probs[i]);
                }
                request->add_sequence(new_sequence);
            }
            if (!is_empty_generated_tokens) {
                if (to_remove_tokens > 0) {
                    // request->decrease_processed_tokens(to_remove_tokens);
                }
                // to validate tokens/extend kv-cache before generation
                // request->set_validation_len(to_insert_tokens);
            } else if (to_remove_tokens > 0) {
                request->update_processed_tokens_num(request->get_prompt_len());
            }
            return ContinuousBatchingPipeline::ContinuousBatchingImpl::UpdateSeqResult(to_insert_tokens, to_remove_tokens);
        }
    }
    return {0, 0};
}
}
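The three new methods above (finish_request, get_generated_sequences, update_generated_sequence), together with the validation-mode constructor, read as groundwork for running a draft pipeline alongside a main one. A minimal sketch of how a caller might wire them together follows; the two-pipeline setup and the control flow are assumptions inferred from this diff, not a documented API of this commit.

using CBImpl = ContinuousBatchingPipeline::ContinuousBatchingImpl;

// One hypothetical draft/validate round: the draft pipeline proposes tokens, the
// candidates are transplanted into the main pipeline, and the main model confirms them.
void speculative_round(CBImpl& main_pipeline, CBImpl& draft_pipeline) {
    // Draft model proposes the next chunk of tokens for every running request.
    draft_pipeline.step();

    // Copy each draft sequence into the main pipeline. update_generated_sequence()
    // trims tokens that diverge from what the main pipeline already holds, appends
    // the new candidates, and returns an UpdateSeqResult with the inserted/removed
    // token counts.
    for (const auto& draft_seq : draft_pipeline.get_generated_sequences()) {
        main_pipeline.update_generated_sequence(draft_seq);
    }

    // Main model runs once in validation mode to accept or reject the candidates.
    main_pipeline.step();

    // Mirror the (possibly corrected) main sequences back into the draft pipeline so
    // the next round starts from the accepted prefix.
    for (const auto& main_seq : main_pipeline.get_generated_sequences()) {
        draft_pipeline.update_generated_sequence(main_seq);
    }
}

finish_request(-1) frees every queued request, and any other id frees just that request; a driver like the sketch above would presumably call it when abandoning a generation early.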
(Diffs for the remaining 7 changed files are not shown here.)
