fix: use num tokens to decode to replace spare latency
chenzhuofu committed Dec 8, 2024
1 parent 4c1b2ce commit 9fb8885
Showing 1 changed file with 138 additions and 11 deletions.
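The change replaces the spare-latency priority in prune_token_tree (expected latency minus observed decode latency) with a token-deficit metric: how many tokens a request must still decode to stay on its SLO. A minimal standalone sketch of that metric, with hypothetical parameter names standing in for the fields the real code reads from RequestManager and Request:

#include <algorithm>

// Sketch of the metric this commit introduces; all inputs are hypothetical
// stand-ins for RequestManager/Request state.
double num_tokens_to_decode(double ssm_spec_latency_ms,
                            double llm_verify_latency_ms,
                            double correction_factor,
                            double slo_ms_per_token,
                            double decode_latency_ms,
                            double decode_length) {
  // Tokens one speculation step must yield to keep pace with the SLO.
  double per_step = (ssm_spec_latency_ms + llm_verify_latency_ms) *
                    correction_factor / slo_ms_per_token;
  // Tokens the request should already have produced under its SLO.
  double expected_so_far = decode_latency_ms / slo_ms_per_token;
  // Deficit relative to what was actually decoded; request at least one.
  return std::max(1.0, per_step + expected_so_far - decode_length);
}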
149 changes: 138 additions & 11 deletions src/runtime/request_manager.cc
@@ -1168,7 +1168,11 @@ BatchConfig RequestManager::prepare_next_batch() {
}
break;
case DECODING:
-      return prepare_decoding_batch();
+      if (get_fcfs_slo()) {
+        return prepare_decoding_batch_fcfs_slo();
+      } else {
+        return prepare_decoding_batch();
+      }
case SSM_SPEC:
if (current_ssm_step == 0) {
return prepare_first_spec_batch_config();
@@ -1399,6 +1403,121 @@ BatchConfig RequestManager::prepare_decoding_batch() {
profiling.llm_step_start = Realm::Clock::current_time_in_microseconds();
return bc;
}

BatchConfig RequestManager::prepare_decoding_batch_fcfs_slo() {
  // This function is called when the request_manager_status is DECODING. It
  // fills the last generated token of each available request into the
  // BatchConfig for the LLM to decode, after gating the batch by SLO category.
if (verbose) {
std::cout << "\n############### prepare_decoding_batch_fcfs_slo "
"##############\n";
}

BatchConfig bc;
bc.inference_mode = InferenceMode::INC_DECODING_MODE;
bc.prompt_phase = false;

  // Check whether any request falls into the fastest SLO category
  // (slo_ratio <= 1.0)
std::copy(std::begin(request_available),
std::end(request_available),
std::begin(bc.request_available));
bc.num_available_requests = num_available_requests;
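  // Note: bc.request_available is a snapshot of the manager's availability
  // mask taken here; any eviction below must also update the copy in bc.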
bool has_fastest_slo = false;
  for (int request_index = 0; request_index < get_max_requests_per_batch();
       request_index++) {
    if (!request_available[request_index]) {
continue;
}
Request &request = all_requests[guid_of_requests[request_index]];
assert(request.status == Request::RUNNING);

if (request.get_slo_ratio() <= 1.0) {
has_fastest_slo = true;
break;
}
}

  // If more than 8 requests fall into the fastest SLO category, evict every
  // slower-SLO request from this batch so the fastest ones are not delayed
if (has_fastest_slo) {
int num_fastest_slo_requests = 0;
for (int request_index = 0; request_index < get_max_requests_per_batch();
request_index++) {
if (!request_available[request_index]) {
continue;
}
Request &request = all_requests[guid_of_requests[request_index]];
assert(request.status == Request::RUNNING);

if (request.get_slo_ratio() <= 1.0) {
num_fastest_slo_requests++;
}
}

    if (num_fastest_slo_requests > 8) {
      int num_evicted_requests = 0;
      for (int request_index = 0; request_index < get_max_requests_per_batch();
           request_index++) {
        if (!request_available[request_index]) {
          continue;
        }
        Request &request = all_requests[guid_of_requests[request_index]];
        assert(request.status == Request::RUNNING);

        if (request.get_slo_ratio() > 1.0) {
          request_available[request_index] = false;
          // Keep the BatchConfig's snapshot in sync with the manager's mask.
          bc.request_available[request_index] = false;
          num_available_requests--;
          num_evicted_requests++;
        }
      }
      bc.num_available_requests -= num_evicted_requests;
    }
}

for (int request_index = 0; request_index < get_max_requests_per_batch();
request_index++) {
if (!request_available[request_index]) {
continue;
}
Request &request = all_requests[guid_of_requests[request_index]];
assert(request.status == Request::RUNNING);

// Per Request Info
bc.requestsInfo[request_index].first_token_index_in_request =
request.llm_cache_size;
bc.requestsInfo[request_index].first_token_offset_in_batch = bc.num_tokens;
bc.requestsInfo[request_index].num_tokens_in_batch = 1;

// Copy the streaming cache info
bc.streamingCacheInfo[request_index] = request.streaming_cache_info;

request.first_token_offset_in_batch = bc.num_tokens;
request.num_tokens_in_batch = 1;

// Per Token Info
bc.tokensInfo[bc.num_tokens].request_index = request_index;
bc.tokensInfo[bc.num_tokens].abs_index_in_request = request.llm_cache_size;
bc.tokensInfo[bc.num_tokens].abs_depth_in_request = request.llm_cache_size;
bc.tokensInfo[bc.num_tokens].token_id = request.tokens.back();

bc.num_tokens++;

if (profiling_requests[request.guid].llm_decoding_steps == 0) {
profiling_requests[request.guid].start_decoding_time =
Realm::Clock::current_time_in_microseconds();
}
}

if (verbose) {
std::cout << "prepare_decoding_batch_fcfs_slo NEW batchconfig:"
<< std::endl;
bc.print();
}
profiling.llm_step_start = Realm::Clock::current_time_in_microseconds();
return bc;
}
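The SLO gating above can be exercised in isolation. A minimal sketch, assuming a hypothetical SloRequest struct and a std::vector in place of the manager's fixed-size request table:

#include <vector>

// Hypothetical stand-in for the fields the gating policy reads.
struct SloRequest {
  double slo_ratio; // <= 1.0 marks the fastest SLO category
  bool available;
};

// If more than `threshold` requests are in the fastest SLO category, evict
// every slower request from the batch; returns the number evicted.
int evict_slow_requests(std::vector<SloRequest> &batch, int threshold = 8) {
  int num_fastest = 0;
  for (auto const &r : batch) {
    if (r.available && r.slo_ratio <= 1.0) {
      num_fastest++;
    }
  }
  if (num_fastest <= threshold) {
    return 0; // few enough fast requests; keep the whole batch
  }
  int evicted = 0;
  for (auto &r : batch) {
    if (r.available && r.slo_ratio > 1.0) {
      r.available = false;
      evicted++;
    }
  }
  return evicted;
}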
/* ----- Speculative Inference Specific functions ----- */

/***** Request Init Phase *****/
@@ -3146,8 +3265,8 @@ void RequestManager::prune_token_tree() {
int budget = get_max_tokens_per_batch() - num_available_requests;
assert(budget >= 0);

-  std::vector<std::pair<double, int>> spare_latency_2_request_index;
-  spare_latency_2_request_index.reserve(get_max_requests_per_batch());
+  std::vector<std::pair<double, int>> num_tokens_to_decode_2_request_index;
+  num_tokens_to_decode_2_request_index.reserve(get_max_requests_per_batch());
for (int request_index = 0; request_index < get_max_requests_per_batch();
++request_index) {
if (!request_available[request_index]) {
@@ -3159,22 +3278,30 @@
if (request.get_slo_ratio() > 999) { // infinity
continue;
}
-    double spare_latency =
-        get_request_expected_latency(request) - request.decode_latency_ms;
-    spare_latency_2_request_index.push_back(
-        std::make_pair(spare_latency, request_index));
+    double num_tokens_to_decode_per_step =
+        (ssm_spec_latency_ms + llm_verify_latency_ms) * correction_factor /
+        get_slo_constraint(request);
+    double expected_num_tokens_decoded =
+        request.decode_latency_ms / get_slo_constraint(request);
+    double num_tokens_to_decode =
+        std::max(1.0,
+                 num_tokens_to_decode_per_step + expected_num_tokens_decoded -
+                     request.decode_length());
+    num_tokens_to_decode_2_request_index.push_back(
+        std::make_pair(num_tokens_to_decode, request_index));
}

  // Sort the requests by the number of tokens to decode, in ascending order
-  std::sort(spare_latency_2_request_index.begin(),
-            spare_latency_2_request_index.end(),
+  std::sort(num_tokens_to_decode_2_request_index.begin(),
+            num_tokens_to_decode_2_request_index.end(),
            std::less<std::pair<double, int>>());

-  for (auto const &spare_latency_request_index_pair :
-       spare_latency_2_request_index) {
-    int request_index = spare_latency_request_index_pair.second;
+  for (auto const &num_tokens_to_decode_request_index_pair :
+       num_tokens_to_decode_2_request_index) {
+    int request_index = num_tokens_to_decode_request_index_pair.second;
    RequestGuid guid = guid_of_requests[request_index];
-    add_tokens_toward_slo(guid, budget, spare_latency_2_request_index.size());
+    add_tokens_toward_slo(
+        guid, budget, num_tokens_to_decode_2_request_index.size());
}

assert(budget >= 0);
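Net effect in prune_token_tree: requests are ranked by token deficit, smallest first, and the per-batch token budget is spent in that order. A compact sketch of the allocation pattern, with a hypothetical add_tokens_toward_slo that merely draws from the budget (the real function grows the request's token tree):

#include <algorithm>
#include <utility>
#include <vector>

// Hypothetical stand-in: grant up to `deficit` tokens out of `budget`.
static int add_tokens_toward_slo(double deficit, int &budget) {
  int granted = std::min(budget, static_cast<int>(deficit));
  budget -= granted;
  return granted;
}

// Rank by token deficit (ascending) and spend the batch-wide budget.
void allocate_tokens(std::vector<std::pair<double, int>> &deficit_and_index,
                     int budget) {
  std::sort(deficit_and_index.begin(), deficit_and_index.end());
  for (auto const &[deficit, request_index] : deficit_and_index) {
    (void)request_index; // index into the manager's request table
    if (budget <= 0) {
      break; // budget exhausted
    }
    add_tokens_toward_slo(deficit, budget);
  }
}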
