
Commit 60de06f

Merge pull request #300 from yuanphoenix/master
Fix the duplicate timestamp issue
ztxz16 authored Sep 27, 2023
2 parents 73c15c0 + 6df1af0 commit 60de06f
Showing 4 changed files with 14 additions and 13 deletions.
7 changes: 4 additions & 3 deletions pyfastllm/demo/web_api.py
@@ -7,7 +7,8 @@
 import traceback
 from typing import List
 sys.path.append('../../build-py')
-import pyfastllm # or fastllm
+import pyfastllm
+import uuid
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
 import threading, queue, uvicorn, json, time
@@ -106,8 +107,8 @@ def dynamic_batch_stream_func():


 def chat_stream(prompt: str, config: pyfastllm.GenerationConfig, uid:int=0, time_out=200):
-    global g_model, g_msg_dict
-    time_stamp = round(time.time() * 1000)
+    global g_msg_dict
+    time_stamp = str(uuid.uuid1())
     hash_id = str(pyfastllm.std_hash(f"{prompt}time_stamp:{time_stamp}"))
     thread = threading.Thread(target = batch_response_stream, args = (f"{prompt}time_stamp:{time_stamp}", config))
     thread.start()
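The rationale for the web_api.py change: `round(time.time() * 1000)` only resolves to the millisecond, so two requests arriving close together received the same stamp and therefore the same `hash_id`, crossing their response streams. `uuid.uuid1()` is distinct on every call. A minimal standalone sketch of the collision (illustration only, not part of the patch):

```python
import time
import uuid

# Simulate a burst of requests asking for stamps back to back.
stamps = [round(time.time() * 1000) for _ in range(1000)]
uuids = [str(uuid.uuid1()) for _ in range(1000)]

# Millisecond stamps collide heavily in a tight loop (and can collide for
# concurrent HTTP requests); uuid1 combines a 100 ns timestamp with a clock
# sequence and node id, so every call yields a new value.
print(f"unique timestamps: {len(set(stamps))} / 1000")  # typically just a few
print(f"unique uuids:      {len(set(uuids))} / 1000")   # always 1000
```

One design note: uuid1 embeds the host's MAC address in the value; if that mattered here, uuid4 would serve the same deduplication purpose.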
8 changes: 4 additions & 4 deletions src/models/basellm.cpp
@@ -61,8 +61,8 @@ namespace fastllm {
 #endif
         std::string prompt = input;
 #ifdef PY_API
-        size_t pos = input.find_last_of("time_stamp:");
-        prompt = (generationConfig.enable_hash_id && pos != std::string::npos) ? input.substr(0, pos - 10) : input;
+        size_t pos = input.rfind("time_stamp:");
+        prompt = (generationConfig.enable_hash_id && pos != -1) ? input.substr(0, pos) : input;
         size_t hash_id = std::hash<std::string>{}(input);
 #endif
         Data inputIds, attentionMask, positionIds;
@@ -151,8 +151,8 @@ namespace fastllm {
             size_t hash_id = std::hash<std::string>{}(_input);
             hash_ids.push_back(hash_id);

-            size_t pos = _input.find_last_of("time_stamp:");
-            std::string prompt = (generationConfig.enable_hash_id && pos != std::string::npos) ? _input.substr(0, pos - 10) : _input;
+            size_t pos = _input.rfind("time_stamp:");
+            std::string prompt = (generationConfig.enable_hash_id && pos != -1) ? _input.substr(0, pos) : _input;
             prompts.push_back(prompt);
         }
 #else
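The C++ change, repeated in basellm.cpp, llama.cpp, and moss.cpp, fixes how the "time_stamp:" marker is located. `find_last_of` searches for the last occurrence of any single character from the set {t, i, m, e, _, s, a, p, :}, not of the substring. With a purely numeric stamp the last matching character happened to be the trailing ':', so `substr(0, pos - 10)` recovered the prompt by accident (the ':' sits 10 characters after the start of the 11-character marker). A uuid1 stamp contains hex letters such as 'a' and 'e' that are also in that set, so `find_last_of` would land inside the stamp; `rfind` searches for the whole substring and returns its start. A standalone sketch of the difference (illustration only, with a made-up UUID):

```cpp
#include <cassert>
#include <iostream>
#include <string>

int main() {
    // A prompt with a uuid1-style stamp appended, as web_api.py now builds it.
    std::string input = "hi!time_stamp:6df1af0e-5d1a-11ee-8c99-0242ac120002";

    // Matches ANY character of the set "time_stamp:"; here it lands on the
    // last 'a' inside the UUID, far past the marker.
    size_t bad = input.find_last_of("time_stamp:");

    // Matches the whole substring and returns the index where it starts.
    size_t good = input.rfind("time_stamp:");
    assert(good == 3 && input.substr(0, good) == "hi!");

    // std::string::npos is size_t(-1), so the patched check pos != -1 is
    // equivalent to pos != std::string::npos after the integer conversion.
    std::cout << "find_last_of -> " << bad << ", rfind -> " << good << "\n";
}
```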
8 changes: 4 additions & 4 deletions src/models/llama.cpp
@@ -564,8 +564,8 @@ namespace fastllm {
 #endif
         //auto st = std::chrono::system_clock::now();
 #ifdef PY_API
-        size_t pos = input.find_last_of("time_stamp:");
-        std::string prompt = (generationConfig.enable_hash_id && pos != std::string::npos)? input.substr(0, pos-10):input;
+        size_t pos = input.rfind("time_stamp:");
+        std::string prompt = (generationConfig.enable_hash_id && pos != -1)? input.substr(0, pos):input;
         size_t hash_id = std::hash<std::string>{}(input);
         Data inputIds = this->weight.tokenizer.Encode(prompt);
 #else
@@ -681,8 +681,8 @@ namespace fastllm {
             size_t hash_id = std::hash<std::string>{}(_input);
             hash_ids.push_back(hash_id);

-            size_t pos = _input.find_last_of("time_stamp:");
-            std::string prompt = (generationConfig.enable_hash_id && pos != std::string::npos) ? _input.substr(0, pos - 10) : _input;
+            size_t pos = _input.rfind("time_stamp:");
+            std::string prompt = (generationConfig.enable_hash_id && pos != -1) ? _input.substr(0, pos) : _input;
             prompts.push_back(prompt);
         }
 #else
4 changes: 2 additions & 2 deletions src/models/moss.cpp
@@ -207,8 +207,8 @@ namespace fastllm {
                           RuntimeResult retCb,
                           const GenerationConfig &generationConfig) {
 #ifdef PY_API
-        size_t pos = input.find_last_of("time_stamp:");
-        std::string prompt = (generationConfig.enable_hash_id && pos != std::string::npos)? input.substr(0, pos-10):input;
+        size_t pos = input.rfind("time_stamp:");
+        std::string prompt = (generationConfig.enable_hash_id && pos != -1)? input.substr(0, pos):input;
         size_t hash_id = std::hash<std::string>{}(input);
         Data inputIds = this->weight.tokenizer.Encode(prompt);
 #else
