From bf323887e1db1e8b53047810f5eebeddf75e72ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?= Date: Thu, 20 Jun 2024 17:32:05 +0800 Subject: [PATCH] =?UTF-8?q?api=20server=E7=A8=8B=E5=BA=8F=E6=8F=90?= =?UTF-8?q?=E4=BE=9B=E4=B8=80=E4=B8=AA=E7=AE=80=E6=98=93=E7=9A=84openai=20?= =?UTF-8?q?api=20server=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- example/apiserver/apiserver.cpp | 325 +++++++++++++++++++++++++++----- 1 file changed, 275 insertions(+), 50 deletions(-) diff --git a/example/apiserver/apiserver.cpp b/example/apiserver/apiserver.cpp index b78d84aa..d526a1a2 100644 --- a/example/apiserver/apiserver.cpp +++ b/example/apiserver/apiserver.cpp @@ -121,6 +121,27 @@ using socket_t = int; #include #include "model.h" +long long GetCurrentTime() { + auto now = std::chrono::high_resolution_clock::now(); + auto duration = now.time_since_epoch(); + return std::chrono::duration_cast(duration).count(); +} + +std::string GenerateRandomID() { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, 15); + + std::stringstream ss; + for (int i = 0; i < 36; ++i) { + if (i == 8 || i == 13 || i == 18 || i == 23) { + ss << '-'; + } + ss << std::hex << dis(gen); + } + return ss.str(); +} + std::map dataTypeDict = { {"float32", fastllm::DataType::FLOAT32}, {"half", fastllm::DataType::FLOAT16}, @@ -133,15 +154,21 @@ std::map dataTypeDict = { struct APIConfig { std::string path = "chatglm-6b-int4.bin"; // 模型文件路径 - std::string webPath = "web"; // 网页文件路径 + std::string modelName = "fastllm"; + int threads = 4; // 使用的线程数 bool lowMemMode = false; // 是否使用低内存模式 + bool cudaEmbedding = false; // 是否使用cudaEmbedding int port = 8080; // 端口号 int tokens = -1; // token容量限制 int batch = 256; // batch数限制 fastllm::DataType dtype = fastllm::DataType::FLOAT16; + fastllm::DataType atype = fastllm::DataType::FLOAT32; int groupCnt = -1; + + std::map devices; }; +APIConfig config; void ToNext(char * &cur, const std::string &target, std::string &v) { v = ""; @@ -292,62 +319,244 @@ struct WorkQueue { } void Deal(WorkNode *node) { - auto *req = &node->request; - if (req->route != "/generate" || req->method != "POST") { - close(node->client); - return; - } - std::string message = ""; message += "HTTP/1.1 200 OK\r\n"; message += "Content-Type:application/json\r\n"; message += "server:fastllm api server\r\n"; message += "\r\n"; - if (node->error == "") { - if (node->config["prompt"].is_null()) { - node->error = "prompt is empty!"; + auto *req = &node->request; + if ((req->route == "/generate" || req->route == "/generate/") && req->method == "POST") { + if (node->error == "") { + if (node->config["prompt"].is_null()) { + node->error = "prompt is empty!"; + } + } + if (node->error != "") { + printf("error body = %s, prompt = %s, error = %s\n", node->request.body.c_str(), node->config["prompt"].string_value().c_str(), node->error.c_str()); + message += node->error; + int ret = write(node->client, message.c_str(), message.length()); //返回error + close(node->client); + return; + } + + std::string output = ""; + fastllm::ChatMessages messages; + messages.push_back({"user", node->config["prompt"].string_value()}); + auto prompt = model->ApplyChatTemplate(messages); + auto inputs = model->weight.tokenizer.Encode(prompt); + std::vector tokens; + for (int i = 0; i < inputs.Count(0); i++) { + tokens.push_back(((float *) inputs.cpuData)[i]); + } + fastllm::GenerationConfig config; + config.output_token_limit = node->config["max_tokens"].is_null() ? 200 : node->config["max_tokens"].int_value(); + int handleId = model->LaunchResponseTokens(tokens, config); + std::vector results; + while (true) { + int result = model->FetchResponseTokens(handleId); + if (result == -1) { + break; + } else { + results.clear(); + results.push_back(result); + output += model->weight.tokenizer.Decode(fastllm::Data (fastllm::DataType::FLOAT32, {(int)results.size()}, results)); + + std::string cur = (message + output); + int ret = write(node->client, cur.c_str(), cur.length()); //返回message + } } - } - if (node->error != "") { - printf("error body = %s, prompt = %s, error = %s\n", node->request.body.c_str(), node->config["prompt"].string_value().c_str(), node->error.c_str()); - message += node->error; - int ret = write(node->client, message.c_str(), message.length()); //返回error - close(node->client); - return; - } - std::string output = ""; - fastllm::ChatMessages messages; - messages.push_back({"user", node->config["prompt"].string_value()}); - auto prompt = model->ApplyChatTemplate(messages); - auto inputs = model->weight.tokenizer.Encode(prompt); - std::vector tokens; - for (int i = 0; i < inputs.Count(0); i++) { - tokens.push_back(((float *) inputs.cpuData)[i]); - } - fastllm::GenerationConfig config; - config.output_token_limit = node->config["max_tokens"].is_null() ? 200 : node->config["max_tokens"].int_value(); - int handleId = model->LaunchResponseTokens(tokens, config); - std::vector results; - while (true) { - int result = model->FetchResponseTokens(handleId); - if (result == -1) { - break; + message += output; + int ret = write(node->client, message.c_str(), message.length()); //返回message + + close(node->client); + } else if ((req->route == "/v1/chat/completions" || req->route == "/v1/chat/completions/") && req->method == "POST") { + fastllm::ChatMessages chatMessages; + if (node->config["messages"].is_array()) { + for (auto &it : node->config["messages"].array_items()) { + chatMessages.push_back({it["role"].string_value(), it["content"].string_value()}); + } + } else if (node->config["prompt"].is_string()) { + chatMessages.push_back({"user", node->config["prompt"].string_value()}); } else { - results.clear(); - results.push_back(result); - output += model->weight.tokenizer.Decode(fastllm::Data (fastllm::DataType::FLOAT32, {(int)results.size()}, results)); + node->error = "no input.\n"; + } - std::string cur = (message + output); - int ret = write(node->client, cur.c_str(), cur.length()); //返回message + if (node->config["model"].string_value() != ::config.modelName) { + node->error = "The model `" + node->config["model"].string_value() + "` does not exist."; + } + + if (node->error != "") { + message += node->error; + int ret = write(node->client, message.c_str(), message.length()); //返回error + close(node->client); + return; + } + + auto prompt = model->ApplyChatTemplate(chatMessages); + auto inputs = model->weight.tokenizer.Encode(prompt); + std::vector tokens; + for (int i = 0; i < inputs.Count(0); i++) { + tokens.push_back(((float *) inputs.cpuData)[i]); + } + + fastllm::GenerationConfig config; + config.output_token_limit = !node->config["max_tokens"].is_number() ? 256 : node->config["max_tokens"].int_value(); + if (node->config["frequency_penalty"].is_number()) { + config.repeat_penalty = node->config["frequency_penalty"].number_value(); + } + if (node->config["temperature"].is_number()) { + config.temperature = node->config["temperature"].number_value(); + } + if (node->config["top_p"].is_number()) { + config.top_p = node->config["top_p"].number_value(); + } + if (node->config["top_k"].is_number()) { + config.top_k = node->config["top_k"].number_value(); + } + + std::string output = ""; + int handleId = model->LaunchResponseTokens(tokens, config); + bool isStream = false; + if (node->config["stream"].is_bool() && node->config["stream"].bool_value()) { + isStream = true; } - } - message += output; - int ret = write(node->client, message.c_str(), message.length()); //返回message + std::string curId = "fastllm-" + GenerateRandomID(); + auto createTime = GetCurrentTime(); + + if (isStream) { + json11::Json startResult = json11::Json::object { + {"id", curId}, + {"object", "chat.completion.chunk"}, + {"created", createTime}, + {"model", ::config.modelName}, + {"choices", json11::Json::array { + json11::Json::object { + {"index", 0}, + {"delta", json11::Json::object { + {"role", "assistant"} + }}, + {"logprobs", nullptr}, + {"finish_reason", nullptr}, + {"stop_reason", nullptr} + } + }} + }; + std::string cur = (message + "data: " + startResult.dump() + "\r\n"); + int ret = write(node->client, cur.c_str(), cur.length()); //返回初始信息 + + int outputTokens = 0; + std::vector results; + while (true) { + int result = model->FetchResponseTokens(handleId); + if (result == -1) { + json11::Json partResult = json11::Json::object { + {"id", curId}, + {"object", "chat.completion.chunk"}, + {"created", createTime}, + {"model", ::config.modelName}, + {"choices", json11::Json::array { + json11::Json::object { + {"index", 0}, + {"delta", json11::Json::object { + {"content", ""} + }}, + {"logprobs", nullptr}, + {"finish_reason", nullptr}, + {"stop_reason", nullptr} + } + }}, + {"usage", json11::Json::object { + {"prompt_tokens", (int)tokens.size()}, + {"total_tokens", (int)tokens.size() + outputTokens}, + {"completion_tokens", outputTokens} + }} + }; + + std::string cur = ("data: " + partResult.dump() + "\r\n"); + int ret = write(node->client, cur.c_str(), cur.length()); //返回中间信息 + break; + } else { + outputTokens++; + results.clear(); + results.push_back(result); + std::string now = model->weight.tokenizer.Decode(fastllm::Data (fastllm::DataType::FLOAT32, {(int)results.size()}, results)); + json11::Json partResult = json11::Json::object { + {"id", curId}, + {"object", "chat.completion.chunk"}, + {"created", createTime}, + {"model", ::config.modelName}, + {"choices", json11::Json::array { + json11::Json::object { + {"index", 0}, + {"delta", json11::Json::object { + {"content", now} + }}, + {"logprobs", nullptr}, + {"finish_reason", nullptr}, + {"stop_reason", nullptr} + } + }} + }; + + std::string cur = ("data: " + partResult.dump() + "\r\n"); + int ret = write(node->client, cur.c_str(), cur.length()); //返回中间信息 + } + } + + cur = ("data: [DONE]"); + ret = write(node->client, cur.c_str(), cur.length()); //返回message + close(node->client); + } else { + int outputTokens = 0; + std::vector results; + while (true) { + int result = model->FetchResponseTokens(handleId); + if (result == -1) { + break; + } else { + results.clear(); + results.push_back(result); + output += model->weight.tokenizer.Decode(fastllm::Data (fastllm::DataType::FLOAT32, {(int)results.size()}, results)); + outputTokens++; + } + } - close(node->client); + json11::Json result = json11::Json::object { + {"id", curId}, + {"object", "chat.completion"}, + {"created", createTime}, + {"model", ::config.modelName}, + {"choices", json11::Json::array { + json11::Json::object { + {"index", 0}, + {"message", json11::Json::object { + {"role", "assistant"}, + {"content", output} + }}, + {"logprobs", nullptr}, + {"finish_reason", nullptr}, + {"stop_reason", nullptr} + } + }}, + {"usage", json11::Json::object { + {"prompt_tokens", (int)tokens.size()}, + {"total_tokens", (int)tokens.size() + outputTokens}, + {"completion_tokens", outputTokens} + }} + }; + + message += result.dump(); + int ret = write(node->client, message.c_str(), message.length()); //返回message + close(node->client); + } + return; + } else { + close(node->client); + return; + } } } workQueue; @@ -355,13 +564,16 @@ void Usage() { std::cout << "Usage:" << std::endl; std::cout << "[-h|--help]: 显示帮助" << std::endl; std::cout << "<-p|--path> : 模型文件的路径" << std::endl; - std::cout << "<-w|--web> : 网页文件的路径" << std::endl; std::cout << "<-t|--threads> : 使用的线程数量" << std::endl; std::cout << "<-l|--low>: 使用低内存模式" << std::endl; std::cout << "<--dtype> : 设置权重类型(读取hf文件时生效)" << std::endl; - std::cout << "<--batch>: 最大batch数" << std::endl; - std::cout << "<--tokens>: 最大tokens容量" << std::endl; + std::cout << "<--atype> : 设置推理使用的数据类型(float32/float16)" << std::endl; + std::cout << "<--batch> : 最大batch数" << std::endl; + std::cout << "<--tokens> : 最大tokens容量" << std::endl; + std::cout << "<--model_name> : 模型名(openai api中使用)" << std::endl; std::cout << "<--port> : 网页端口号" << std::endl; + std::cout << "<--cuda_embedding>: 使用cuda来执行embedding" << std::endl; + std::cout << "<--device>: 执行设备" << std::endl; } void ParseArgs(int argc, char **argv, APIConfig &config) { @@ -379,8 +591,8 @@ void ParseArgs(int argc, char **argv, APIConfig &config) { config.threads = atoi(sargv[++i].c_str()); } else if (sargv[i] == "-l" || sargv[i] == "--low") { config.lowMemMode = true; - } else if (sargv[i] == "-w" || sargv[i] == "--web") { - config.webPath = sargv[++i]; + } else if (sargv[i] == "--cuda_embedding"){ + config.cudaEmbedding = true; } else if (sargv[i] == "--port") { config.port = atoi(sargv[++i].c_str()); } else if (sargv[i] == "--dtype") { @@ -396,6 +608,15 @@ void ParseArgs(int argc, char **argv, APIConfig &config) { config.tokens = atoi(sargv[++i].c_str()); } else if (sargv[i] == "--batch") { config.batch = atoi(sargv[++i].c_str()); + } else if (sargv[i] == "--atype") { + std::string atypeStr = sargv[++i]; + fastllm::AssertInFastLLM(dataTypeDict.find(atypeStr) != dataTypeDict.end(), + "Unsupport act type: " + atypeStr); + config.atype = dataTypeDict[atypeStr]; + } else if (sargv[i] == "--model_name") { + config.modelName = sargv[++i]; + } else if (sargv[i] == "--device") { + config.devices[sargv[++i]] = 1; } else { Usage(); exit(-1); @@ -408,11 +629,14 @@ std::string url = "generate"; std::mutex locker; int main(int argc, char** argv) { - APIConfig config; ParseArgs(argc, argv, config); + if (config.devices.size() != 0) { + fastllm::SetDeviceMap(config.devices); + } fastllm::SetThreads(config.threads); fastllm::SetLowMemMode(config.lowMemMode); + fastllm::SetCudaEmbedding(config.cudaEmbedding); if (!fastllm::FileExists(config.path)) { printf("模型文件 %s 不存在!\n", config.path.c_str()); exit(0); @@ -421,6 +645,7 @@ int main(int argc, char** argv) { workQueue.model = isHFDir ? fastllm::CreateLLMModelFromHF(config.path, config.dtype, config.groupCnt) : fastllm::CreateLLMModelFromFile(config.path); workQueue.model->tokensLimit = config.tokens; + workQueue.model->SetDataType(config.atype); workQueue.maxActivateQueryNumber = std::max(1, std::min(256, config.batch)); workQueue.Start();