When the prompt is too long, the api server returns immediately
ztxz16 committed Jul 23, 2024
1 parent bf8ba72 commit b3b3503
Showing 3 changed files with 30 additions and 9 deletions.
6 changes: 6 additions & 0 deletions include/models/basellm.h
@@ -22,6 +22,10 @@ using RuntimeResultBatch = std::function<void(int index, std::vector <std::strin
namespace fastllm {
using ChatMessages = std::vector <std::pair <std::string, std::string> >;

enum ResponseContextError {
ResponseContextErrorNone = 0, ResponseContextErrorPromptTooLong
};

struct ResponseContext {
bool isEnding = false;
std::vector <std::pair <Data, Data> > pastKeyValues;
@@ -30,6 +34,7 @@ namespace fastllm {
std::queue <std::vector <float>*> resultLogits;
GenerationConfig generationConfig;
LastTokensUnit tokens;
ResponseContextError error = ResponseContextErrorNone;

int preTokens = 0;
int curTokens = 0;
@@ -232,6 +237,7 @@ namespace fastllm {
std::string adapterName;

int tokensLimit = -1;
int promptLimit = -1;

PastKVCacheManager pastKVCacheManager;
bool saveHistoryChat = false;
17 changes: 15 additions & 2 deletions src/models/basellm.cpp
@@ -513,9 +513,13 @@ namespace fastllm {
}

model->tokensLimit = maxTotalLens;
int limit = maxTotalLens;
model->promptLimit = limit * 2 / 3;

if (model->verbose) {
printf("Fastllm KV Cache Limit: %f MB.\n", (double)kvCacheLimit / 1024 / 1024);
printf("Fastllm KV Cache Token limit: %d tokens.\n", maxTotalLens);
printf("Fastllm Prompt Token limit: %d tokens.\n", std::min(model->max_positions, model->promptLimit));
printf("Fastllm Batch limit: %d.\n", maxBatch);
}
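
For reference, the new prompt budget is two thirds of the total KV-cache token budget (the printed prompt limit is further capped by the model's max_positions). A minimal Python sketch of the same arithmetic and the resulting accept/reject rule, using an illustrative budget that is not taken from this commit:

# Illustrative values: with a 4096-token KV-cache budget, a single prompt
# may use at most two thirds of it.
max_total_lens = 4096
prompt_limit = max_total_lens * 2 // 3   # 2730 tokens

def accepts_prompt(num_prompt_tokens: int) -> bool:
    # Mirrors the scheduler check introduced below: longer prompts are marked
    # as ending with ResponseContextErrorPromptTooLong instead of being scheduled.
    return num_prompt_tokens <= prompt_limit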

@@ -538,7 +542,7 @@ namespace fastllm {
std::unique_lock<std::mutex> dictLocker(model->dictLocker);

int limit = maxTotalLens;
int promptLimit = limit * 2 / 3;
int promptLimit = model->promptLimit;

int lenSum = 0;
for (auto &it: model->responseContextDict.dicts) {
@@ -566,6 +570,11 @@ namespace fastllm {
if (!isPrompt && it.second->preTokens == 0) {
continue;
}
if (it.second->currentTokens.size() > promptLimit) {
it.second->isEnding = true;
it.second->error = ResponseContextErrorPromptTooLong;
continue;
}

int outputLimit = it.second->generationConfig.output_token_limit;
outputLimit = (outputLimit < 0 ? 128 : outputLimit);
@@ -794,7 +803,11 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to
} else {
if (context->isEnding) {
responseContextDict.RemoveHandle(handleId);
return -1;
if (context->error == ResponseContextErrorNone) {
return -1;
} else if (context->error == ResponseContextErrorPromptTooLong) {
return -2;
}
}
}
dictLocker.unlock();
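
With this change, fetch_response_llm_model has two terminal return codes instead of one: -1 still means generation ended normally, while -2 means the request was rejected because its prompt exceeded the prompt token limit. A minimal sketch of a polling loop on the Python side that tells the two apart; the constant and function names are illustrative, not part of fastllm's API, and it assumes fastllm_pytools is installed so the module-level ctypes handle fastllm_lib can be imported from llm.py:

from fastllm_pytools.llm import fastllm_lib

FETCH_END = -1              # generation finished normally
FETCH_PROMPT_TOO_LONG = -2  # request rejected: prompt exceeded the prompt token limit

def drain_tokens(model, handle):
    # Poll the native library until a terminal code arrives.
    tokens = []
    while True:
        cur = fastllm_lib.fetch_response_llm_model(model, handle)
        if cur == FETCH_PROMPT_TOO_LONG:
            raise ValueError("prompt too long")
        if cur <= FETCH_END:
            break
        tokens.append(cur)
    return tokens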
16 changes: 9 additions & 7 deletions tools/fastllm_pytools/llm.py
@@ -496,7 +496,7 @@ def generate(
for i in range(len(inputs)):
while True:
cur_token = fastllm_lib.fetch_response_llm_model(self.model, handles[i])
if cur_token == -1:
if cur_token <= -1:
break
outputs[i].append(cur_token)
return outputs
@@ -697,7 +697,7 @@ def stream_response(self,
if not(fastllm_lib.can_fetch_response_llm_model(self.model, handle)):
continue
cur = fastllm_lib.fetch_response_llm_model(self.model, handle)
if (cur == -1):
if (cur <= -1):
break
tokens.append(cur)
ret = tokenizer.decode(tokens)
@@ -789,8 +789,10 @@ async def stream_response_async(self,
await asyncio.sleep(0)
continue
cur = fastllm_lib.fetch_response_llm_model(self.model, handle)
if (cur == -1):
break;
if (cur <= -1):
if (cur == -2):
yield "prompt too long"
break
tokens.append(cur)
ret = tokenizer.decode(tokens)
if (ret.encode().find(b'\xef\xbf\xbd') == -1):
@@ -857,7 +859,7 @@ def stream_response_raw(self,
total_bytes = b''
while True:
cur_token = fastllm_lib.fetch_response_llm_model(self.model, handle)
if cur_token == -1:
if cur_token <= -1:
break

cur_bytes = self.tokenizer_decode_token(cur_token)
@@ -882,7 +884,7 @@ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max
result = [];
while True:
cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
if (cur == -1):
if (cur <= -1):
break;
result.append(cur);
response = tokenizer.decode(result);
@@ -914,7 +916,7 @@ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = No
tokens = [];
while True:
cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
if (cur == -1):
if (cur <= -1):
break;
tokens.append(cur);
response = tokenizer.decode(tokens);
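
In stream_response_async, the new -2 code surfaces to callers as the literal chunk "prompt too long" just before the stream ends, which lets an api server answer the request immediately instead of waiting for generation. A hedged sketch of a hypothetical consumer; the exact parameters of stream_response_async are not shown in this diff, so the call is schematic and assumes the generator yields plain text chunks as the diff above suggests:

import asyncio

async def collect_reply(model, prompt: str) -> dict:
    # Hypothetical consumer: stop and report an error as soon as the stream
    # signals a rejected prompt; otherwise accumulate the generated text.
    reply = ""
    async for chunk in model.stream_response_async(prompt):
        if chunk == "prompt too long":
            return {"error": "prompt too long"}
        reply += chunk
    return {"content": reply}

# Usage (schematic): asyncio.run(collect_reply(model, very_long_prompt))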
