From 67eba1120f6ec940812c2031e87d5105967ec650 Mon Sep 17 00:00:00 2001
From: cgli
Date: Sun, 21 Jul 2024 13:36:50 +0800
Subject: [PATCH] =?UTF-8?q?Jinja=E6=A8=A1=E6=9D=BF=E6=94=AF=E6=8C=81"is=20?=
 =?UTF-8?q?defined"=E5=92=8C=E2=80=9Celif=E2=80=9D=EF=BC=8C=E6=94=AF?=
 =?UTF-8?q?=E6=8C=81DeepSeek=20V2=E7=B3=BB=E5=88=97=E6=A8=A1=E6=9D=BF?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 docs/models.md         |   24 +-
 include/template.h     |    6 +-
 src/models/basellm.cpp | 2166 ++++++++++++++++++++--------------------
 src/template.cpp       |   23 +-
 4 files changed, 1121 insertions(+), 1098 deletions(-)

diff --git a/docs/models.md b/docs/models.md
index 9e19ac6c..69b85327 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -68,20 +68,24 @@
 | Qwen/Qwen2-7B-Instruct | [✔](#其它模型) | [✔](#qwen模型导出) | ✔ |
 | Qwen/Qwen2-72B-Instruct | | [✔](#qwen模型导出) | ✔ |
 
-> 注3: 需要更新,检查 tokenizer_config.json 是否为最新版本
+> 注3: 需要更新,检查 `tokenizer_config.json` 是否为最新版本
 
 ### DeepSeek系列
 
 | 模型 | 加载后转换 | 离线转换 | 直接读取 |
 |-------------------------------------------: |------------|------------|------------|
-| deepseek-ai/Deepseek-Coder-1.3B-Instruct | [✔](llama_cookbook.md#deepseek-coder) | [✔](llama_cookbook.md#deepseek-coder) | ✔ |
-| deepseek-ai/Deepseek-Coder-6.7B-Instruct | [✔](llama_cookbook.md#deepseek-coder) | [✔](llama_cookbook.md#deepseek-coder) | ✔ |
-| deepseek-ai/Deepseek-Coder-7B-Instruct v1.5 | [✔](llama_cookbook.md#deepseek-coder) | [✔](llama_cookbook.md#deepseek-coder) | ❌4 |
-| deepseek-ai/deepseek-coder-33b-instruct | [√](llama_cookbook.md#deepseek-coder) | [√](llama_cookbook.md#deepseek-coder) | ❌4 |
-| deepseek-ai/DeepSeek-V2-Chat | √ | ✔ | √4 |
-| deepseek-ai/DeepSeek-V2-Lite-Chat | √ | ✔ | √4 |
-| deepseek-ai/DeepSeek-Coder-V2-Instruct | √ | √ | √4 |
-| deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct | √ | √ | √4 |
+| deepseek-ai/Deepseek-Coder-1.3B-Instruct | [✔](llama_cookbook.md#deepseek-coder) | [✔](llama_cookbook.md#deepseek-coder) | ✔45 |
+| deepseek-ai/Deepseek-Coder-6.7B-Instruct | [✔](llama_cookbook.md#deepseek-coder) | [✔](llama_cookbook.md#deepseek-coder) | ✔45 |
+| deepseek-ai/Deepseek-Coder-7B-Instruct v1.5 | [✔](llama_cookbook.md#deepseek-coder) | [✔](llama_cookbook.md#deepseek-coder) | ✔4 |
+| deepseek-ai/deepseek-coder-33b-instruct | [√](llama_cookbook.md#deepseek-coder) | [√](llama_cookbook.md#deepseek-coder) | ✔4 |
+| deepseek-ai/DeepSeek-V2-Chat | √ | ✔ | √ |
+| deepseek-ai/DeepSeek-V2-Lite-Chat | √ | ✔ | ✔ |
+| deepseek-ai/DeepSeek-Coder-V2-Instruct | √ | ✔ | √ |
+| deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct | √ | ✔ | ✔ |
+
+> 注4: Python ftllm用AutoTokenizer而不使用Fastllm Tokenizer可以实现加载,但是C++程序尚不支持加载该模型的Tokenizer。
+> 注5: C++端仅支持最早的几个 `tokenizer_config.json` 版本
+
 
 ### LLaMA类模型
 
@@ -107,8 +111,6 @@
 | meta-llama/Meta-Llama-3-8B-Instruct | | [✔](tools/scripts/llama3_to_flm.py) | ✔ |
 | meta-llama/Meta-Llama-3-70B-Instruct | | [✔](tools/scripts/llama3_to_flm.py) | ✔ |
 
-> 注4: Python ftllm用AutoTokenizer而不使用Fastllm Tokenizer可以实现加载,但是C++程序尚不支持加载该模型的Tokenizer。
-
 ### 其它模型
 
 | 模型 | 加载后转换 | 离线转换 | 直接读取 |
diff --git a/include/template.h b/include/template.h
index 3db6db79..05bfeb12 100644
--- a/include/template.h
+++ b/include/template.h
@@ -51,7 +51,7 @@ namespace fastllm {
         enum JinjaToKenType {
             JinjaTokenID = 0, JinjaTokenBOOL, JinjaTokenNUM, JinjaTokenSTRING, JinjaTokenDOT, JinjaTokenLMB, JinjaTokenRMB, JinjaTokenLSB, JinjaTokenRSB,
-            JinjaTokenSet, JinjaTokenFor, JinjaTokenEndFor, JinjaTokenIf, JinjaTokenElse, JinjaTokenEndif,
+            JinjaTokenSet, JinjaTokenFor, JinjaTokenEndFor, JinjaTokenIf, JinjaTokenElse, JinjaTokenElseIf, JinjaTokenEndif,
             JinjaTokenIn, JinjaTokenAssign, JinjaTokenNotEqual, JinjaTokenEqual, JinjaTokenAdd, JinjaTokenSub, JinjaTokenMul, JinjaTokenDiv, JinjaTokenNot, JinjaTokenAnd, JinjaTokenOr,
@@ -86,10 +86,12 @@ namespace fastllm {
         {"for", JinjaToken::JinjaToKenType::JinjaTokenFor},
         {"endfor", JinjaToken::JinjaToKenType::JinjaTokenEndFor},
         {"if", JinjaToken::JinjaToKenType::JinjaTokenIf},
+        {"elif", JinjaToken::JinjaToKenType::JinjaTokenElseIf},
         {"else", JinjaToken::JinjaToKenType::JinjaTokenElse},
         {"endif", JinjaToken::JinjaToKenType::JinjaTokenEndif},
        {"set", JinjaToken::JinjaToKenType::JinjaTokenSet},
         {"in", JinjaToken::JinjaToKenType::JinjaTokenIn},
+        {"is", JinjaToken::JinjaToKenType::JinjaTokenIn},
         {"true", JinjaToken::JinjaToKenType::JinjaTokenBOOL},
         {"false", JinjaToken::JinjaToKenType::JinjaTokenBOOL},
         {"and", JinjaToken::JinjaToKenType::JinjaTokenAnd},
@@ -101,7 +103,7 @@ struct JinjaBlock {
         enum JinjaBlockType {
             JinjaBlockOriginal = 0, JinjaBlockEmpty, JinjaBlockVar, JinjaBlockFor,
-            JinjaBlockEndFor, JinjaBlockIf, JinjaBlockElse, JinjaBlockEndIf,
+            JinjaBlockEndFor, JinjaBlockIf, JinjaBlockElseIf, JinjaBlockElse, JinjaBlockEndIf,
             JinjaBlockSet
         };
diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp
index 50487c91..ab1f3862 100644
--- a/src/models/basellm.cpp
+++ b/src/models/basellm.cpp
@@ -1,1082 +1,1084 @@
-//
-// Created by huangyuyang on 6/25/23.
-//
-
-#include "basellm.h"
-#include "utils.h"
-#include
-#include
-
-#ifdef USE_CUDA
-#include "fastllm-cuda.cuh"
-#endif
-
-namespace fastllm {
-    int ResponseContextDict::CreateHandle() {
-        locker.lock();
-        int newId = 0;
-        while (dicts.find(newId) != dicts.end()) {
-            newId++;
-        }
-        dicts[newId] = new ResponseContext();
-        locker.unlock();
-        return newId;
-    }
-
-    ResponseContext *ResponseContextDict::GetHandle(int handleId) {
-        locker.lock();
-        ResponseContext *ret = dicts.find(handleId) != dicts.end() ?
dicts[handleId] : nullptr; - locker.unlock(); - return ret; - } - - void ResponseContextDict::RemoveHandle(int handleId) { - locker.lock(); - if (dicts.find(handleId) != dicts.end()) { - delete dicts[handleId]; - dicts.erase(handleId); - } - locker.unlock(); - } - - void ResponseContext::Init(int blocks, DataType dataType) { - pastKeyValues.clear(); - for (int i = 0; i < blocks; i++) { - pastKeyValues.push_back(std::make_pair(Data(dataType), - Data(dataType))); - pastKeyValues.back().first.SetKVCache(); - pastKeyValues.back().second.SetKVCache(); - } - intParams.clear(); - currentTokens.clear(); - while (resultTokenQueue.size() > 0){ - resultTokenQueue.pop(); - } - isEnding = false; - preTokens = 0; - } - - PastKVCacheMemory::PastKVCacheMemory(const std::vector &inputToken, int tokens, long long flushTime, std::vector > *kv) { - this->inputToken = inputToken; - this->tokens = tokens; - this->flushTime = flushTime; - this->recordTimes = 1; - auto dataType = (*kv)[0].first.dataType; - for (int i = 0; i < kv->size(); i++) { - this->kv.push_back(std::make_pair(Data(dataType), Data(dataType))); - } - for (int i = 0; i < kv->size(); i++) { - this->kv[i].first.CopyFrom((*kv)[i].first); - this->kv[i].second.CopyFrom((*kv)[i].second); - } - } - - void PastKVCacheManager::SetMaxRecordNum(int maxRecordNum) { - std::lock_guard lock(this->locker); - this->maxRecordNum = maxRecordNum; - } - - void PastKVCacheManager::Record(const std::vector &inputToken, int tokens, std::vector > *kv) { - std::lock_guard lock(this->locker); - if (this->memorys.find(inputToken) != this->memorys.end()) { - this->memorys[inputToken]->recordTimes++; - this->memorys[inputToken]->flushTime = ++flushTime; - return; - } - - if (this->memorys.size() >= this->maxRecordNum) { - std::vector eraseToken; - long long minFlushTime = (1LL << 60); - for (auto &it : this->memorys) { - if (it.second->flushTime < minFlushTime) { - minFlushTime = it.second->flushTime; - eraseToken = it.first; - } - } - delete this->memorys[eraseToken]; - this->memorys.erase(this->memorys.find(eraseToken)); - } - - this->memorys[inputToken] = new PastKVCacheMemory(inputToken, tokens, ++flushTime, kv); - } - - void PastKVCacheManager::Remove(const std::vector &inputToken) { - std::lock_guard lock(this->locker); - if (this->memorys.find(inputToken) != this->memorys.end()) { - if ((--this->memorys[inputToken]->recordTimes) <= 0) { - delete this->memorys[inputToken]; - this->memorys.erase(this->memorys.find(inputToken)); - } - } - } - - PastKVCacheMemory *PastKVCacheManager::Get(const std::vector &inputToken) { - std::lock_guard lock(this->locker); - std::vector maxToken; - for (auto &it : this->memorys) { - const std::vector &cur = it.first; - if (cur.size() > maxToken.size() && cur.size() <= inputToken.size()) { - bool match = true; - for (int i = 0; i < cur.size(); i++) { - if (inputToken[i] != cur[i]) { - match = false; - break; - } - } - if (match) { - maxToken = cur; - } - } - } - if (maxToken.size() == 0) { - return nullptr; - } - this->memorys[maxToken]->flushTime = ++this->flushTime; - return this->memorys[maxToken]; - } - - void PastKVCacheManager::Unlock() { - locker.unlock(); - } - - basellm::~basellm() { - dictLocker.lock(); - this->isFree = true; - dictLocker.unlock(); - dictCV.notify_all(); - this->weight.ReleaseWeight(); - } - - std::map > > - basellm::GetTensorMap(const std::vector &tensorNames) { - std::map > > ret; - for (auto &name : tensorNames) { - WeightType weightType = this->weight.GetWeightType(name); - DataType dataType = 
DataType::DATA_AUTO_NONE; - if (weightType == WeightType::LINEAR) { - dataType = DataType::DATA_AUTO_LINEAR; - } else if (weightType == WeightType::EMBEDDING) { - dataType = DataType::DATA_AUTO_EMBEDDING; - } - ret[name].push_back(std::make_pair(name, dataType)); - } - return ret; - } - - std::string basellm::Response(const std::string &oriInput, RuntimeResult retCb, - const fastllm::GenerationConfig &generationConfig) { - std::string input = oriInput; - if (this->saveHistoryChat) { - if (lastKeyValues != nullptr) { - if (input.size() < lastPrompt.size() || (input.substr(0, lastPrompt.size()) != lastPrompt)) { - lastPrompt = ""; - lastPromptTokens = 0; - delete lastKeyValues; - lastKeyValues = nullptr; - } else { - input = input.substr(lastPrompt.size()); - } - } - } else { - lastPrompt = ""; - lastPromptTokens = 0; - delete lastKeyValues; - lastKeyValues = nullptr; - } - - //printf("lastPrompt = %s\n", lastPrompt.c_str()); - //printf("input = %s\n", input.c_str()); - -#ifdef USE_CUDA - FastllmCudaClearBigBuffer(); -#endif - std::string prompt = input; -#ifdef PY_API - size_t pos = input.rfind("time_stamp:"); - prompt = (generationConfig.enable_hash_id && pos != -1) ? input.substr(0, pos) : input; - size_t hash_id = std::hash{}(input); -#endif - Data inputIds, attentionMask, positionIds; - - Data inputTokenData = this->weight.tokenizer.Encode(prompt); - std::vector > inputTokens; - inputTokens.resize(1); - for (int i = 0; i < inputTokenData.Count(0); i++) { - inputTokens[0].push_back(((float *) inputTokenData.cpuData)[i]); - } - - if (lastKeyValues == nullptr) { - lastKeyValues = new std::vector >(); - for (int i = 0; i < block_cnt; i++) { - lastKeyValues->push_back(std::make_pair(Data(this->dataType), Data(this->dataType))); - lastKeyValues->back().first.SetKVCache(); - lastKeyValues->back().second.SetKVCache(); - } - } - - std::vector > &pastKeyValues = (*lastKeyValues); - std::string retString = ""; - std::vector results; - LastTokensManager tokens(1, generationConfig.last_n); - int promptLen = lastPromptTokens + inputTokens[0].size(), index = 0; - int add_special_tokens = generationConfig.add_special_tokens? 
1: 0; - FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}}, - inputIds, attentionMask, positionIds); - ToDataType(attentionMask, this->dataType); - while (true) { - auto st = std::chrono::system_clock::now(); - int ret = Forward(inputIds, attentionMask, positionIds, pastKeyValues, generationConfig, tokens); - tokens.units[0].Push(ret); - if (ret == eos_token_id - || generationConfig.stop_token_ids.find(ret) != generationConfig.stop_token_ids.end() - || eos_token_ids.find(ret) != eos_token_ids.end()) { - break; - } - - results.push_back(ret); - std::string curString = weight.tokenizer.Decode( - Data(DataType::FLOAT32, {(int) results.size()}, results)).c_str(); - retString += curString; - if (retCb) -#ifdef PY_API - { - if (generationConfig.enable_hash_id) { - std::stringstream ss; - ss << retString << "hash_id:"< {(float)ret}; - FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}}, - inputIds, attentionMask, positionIds); - ToDataType(attentionMask, this->dataType); - if (index == generationConfig.output_token_limit) { - break; - } - // printf("len = %d, spend %f s.\n", len, GetSpan(st, std::chrono::system_clock::now())); - } - if (retCb) -#ifdef PY_API - { - if(generationConfig.enable_hash_id){ - std::stringstream ss; - ss << retString << "hash_id:"< &inputs, std::vector &outputs, - RuntimeResultBatch retCb, const fastllm::GenerationConfig &generationConfig) { -#ifdef USE_CUDA - FastllmCudaClearBigBuffer(); -#endif -#ifdef PY_API - std::vector prompts; - std::vector < size_t > hash_ids; - for (auto _input: inputs){ - size_t hash_id = std::hash{}(_input); - hash_ids.push_back(hash_id); - - size_t pos = _input.rfind("time_stamp:"); - std::string prompt = (generationConfig.enable_hash_id && pos != -1) ? _input.substr(0, pos) : _input; - prompts.push_back(prompt); - } -#else - std::vector prompts = inputs; -#endif - // 1. first - Data inputIds, attentionMask, positionIds; - - int batch = prompts.size(); - outputs.clear(); - outputs.resize(batch, ""); - - std::vector > inputTokens; - inputTokens.resize(batch); - - for (int i = 0; i < batch; i++) { - Data now = this->weight.tokenizer.Encode(prompts[i]); - for (int j = 0; j < now.Count(0); j++) { - inputTokens[i].push_back(((float *) now.cpuData)[j]); - } - } - - std::vector > pastKeyValues; - for (int i = 0; i < block_cnt; i++) { - pastKeyValues.push_back(std::make_pair(Data(dataType), - Data(dataType))); - pastKeyValues.back().first.SetKVCache(); - pastKeyValues.back().second.SetKVCache(); - } - - std::vector > params; - params.resize(batch); - for (int i = 0; i < batch; i++) { - params[i]["promptLen"] = (int)inputTokens[i].size(); - } - params[0]["index"] = 0; - int index = 0; - params[0]["add_special_tokens"] = generationConfig.add_special_tokens? 
1: 0; - - LastTokensManager tokensManager (batch, generationConfig.last_n); - std::vector isEnding = std::vector (batch, false); - FillLLMInputsBatch(inputTokens, params, inputIds, attentionMask, positionIds); - ToDataType(attentionMask, this->dataType); - while (true) { - auto st = std::chrono::system_clock::now(); -// ClearProfiler(); - std::vector ret = ForwardBatch(batch, inputIds, attentionMask, positionIds, pastKeyValues, - generationConfig, tokensManager); -// PrintProfiler(); - for (int i = 0; i < batch; i++) { - tokensManager.units[i].Push(ret[i]); - } - std::vector fret; - std::vector results; - int endingCount = 0; - std::vector curStrings; - for (int i = 0; i < batch; i++) { - fret.push_back(ret[i]); - inputTokens[i] = std::vector {(float)ret[i]}; - if (ret[i] == eos_token_id || eos_token_ids.find(ret[i]) != eos_token_ids.end()) { - isEnding[i] = true; - } else { - auto itStopTk = generationConfig.stop_token_ids.find(ret[i]); - if (itStopTk != generationConfig.stop_token_ids.end()) { - isEnding[i] = true; - } - } - if (isEnding[i]) { - curStrings.push_back(""); - endingCount++; - continue; - } - results.push_back(ret[i]); - std::string curString = weight.tokenizer.Decode( - Data(DataType::FLOAT32, {(int) results.size()}, results)).c_str(); - outputs[i] += curString; - curStrings.push_back(curString); - results.clear(); - } - - if (endingCount == batch) { - break; - } - if (retCb) -#ifdef PY_API - { - if (generationConfig.enable_hash_id) { - std::vector rtnStrings; - for (size_t i=0; i rtnStrings; - for (size_t i=0; idataType); - // printf("len = %d, spend %f s.\n", len, GetSpan(st, std::chrono::system_clock::now())); - - if (index == generationConfig.output_token_limit) { - break; - } - } - if (retCb) -#ifdef PY_API - { - if (generationConfig.enable_hash_id) { - std::vector rtnStrings; - for (size_t i=0; i rtnStrings; - for (size_t i=0; i basellm::ForwardBatch(int batch, const fastllm::Data &inputIds, const fastllm::Data &attentionMask, - const fastllm::Data &positionIds, - std::vector> &pastKeyValues, - const fastllm::GenerationConfig &generationConfig, - const fastllm::LastTokensManager &lastTokens, - std::vector *> *retLogits) { - printf("Unsupport forward batch.\n"); - exit(0); - } - - std::vector basellm::ForwardBatch(int batch, const fastllm::Data &inputIds, - const std::vector &attentionMask, - const std::vector &positionIds, const std::vector &seqLens, - std::vector> &pastKeyValues, - const std::vector &generationConfigs, - const fastllm::LastTokensManager &lastTokens, - std::vector *> *logits) { - std::vector ret; - int cur = 0; - for (int i = 0; i < batch; i++) { - std::vector > curKV; - curKV.resize(this->block_cnt); - for (int j = 0; j < this->block_cnt; j++) { - Mul(*pastKeyValues[i * this->block_cnt + j].first, 1.0, curKV[j].first); - Mul(*pastKeyValues[i * this->block_cnt + j].second, 1.0, curKV[j].second); - } - Data curInput; - Split(inputIds, 1, cur, cur + seqLens[i], curInput); - LastTokensManager curTokens; - curTokens.units.push_back(lastTokens.units[i]); - ret.push_back(this->Forward(curInput, *attentionMask[i], *positionIds[i], curKV, generationConfigs[i], curTokens)); - for (int j = 0; j < this->block_cnt; j++) { - Mul(curKV[j].first, 1.0, *pastKeyValues[i * this->block_cnt + j].first); - Mul(curKV[j].second, 1.0, *pastKeyValues[i * this->block_cnt + j].second); - } - } - return ret; - } - - int basellm::LaunchResponseTokens(const std::vector &inputTokens, - const fastllm::GenerationConfig &generationConfig) { - mainLoopLocker.lock(); - if (mainLoop == 
nullptr) { - if (mainLoop == nullptr) { - mainLoop = new std::thread([](basellm *model) { - long long kvCacheLimit = 16LL << 30; -#ifdef USE_CUDA - auto freeSizes = FastllmCudaGetFreeSizes(); - kvCacheLimit = 0; - for (long long i : freeSizes) { - kvCacheLimit += std::max(0LL, i - (2LL << 30)); - } -#endif - if (model->kvCacheLimit > 0) { - kvCacheLimit = model->kvCacheLimit; - } - - int unitSize = (model->dataType == DataType::FLOAT32 ? 4 : 2); - int maxTotalLens = kvCacheLimit / (model->elementsInKVCachePerToken * unitSize); - if (model->elementsInKVCachePerToken <= 0) { - maxTotalLens = kvCacheLimit / 1024 / 1024; - } - if (model->tokensLimit > 0) { - maxTotalLens = model->tokensLimit; - } - - int maxBatch = std::max(1, std::min(512, maxTotalLens / 128)); - if (model->maxBatch > 0) { - maxBatch = model->maxBatch; - } - - model->tokensLimit = maxTotalLens; - if (model->verbose) { - printf("Fastllm KV Cache Limit: %f MB.\n", (double)kvCacheLimit / 1024 / 1024); - printf("Fastllm KV Cache Token limit: %d tokens.\n", maxTotalLens); - printf("Fastllm Batch limit: %d.\n", maxBatch); - } - - auto lastRecordTime = std::chrono::system_clock::now(); - long long genTokens = 0; - while (true) { - if (model->isFree) { - break; - } - std::vector attentionMasks; - std::vector positionIds; - std::vector > pastKeyValues; - std::vector ids; - std::vector seqLens; - std::vector handles; - std::vector generationConfigs; - LastTokensManager tokensManager; - std::vector * > logits; - - std::unique_lock dictLocker(model->dictLocker); - - int limit = maxTotalLens; - int promptLimit = limit * 2 / 3; - - int lenSum = 0; - for (auto &it: model->responseContextDict.dicts) { - if (it.second->pastKeyValues[0].first.expansionDims.size() > 0) { - lenSum += it.second->pastKeyValues[0].first.expansionDims[1]; - } - } - - for (int isPrompt = 1; isPrompt >= 0; isPrompt--) { - int cnt = 0; - if (isPrompt == 0 && seqLens.size() > 0) { - continue; - } - if (lenSum >= promptLimit && isPrompt) { - continue; - } - - for (auto &it: model->responseContextDict.dicts) { - if (it.second->isEnding) { - continue; - } - if (isPrompt && it.second->preTokens != 0) { - continue; - } - if (!isPrompt && it.second->preTokens == 0) { - continue; - } - - int outputLimit = it.second->generationConfig.output_token_limit; - outputLimit = (outputLimit < 0 ? 128 : outputLimit); - if (isPrompt && lenSum + it.second->currentTokens.size() > promptLimit) { - continue; - } - - if (!isPrompt) { - if (it.second->pastKeyValues[0].first.expansionDims[1] == it.second->pastKeyValues[0].first.dims[1]) { - int sur = it.second->generationConfig.output_token_limit - it.second->curTokens; - int predictLen = 256; - if (sur > 0) { - predictLen = std::min(predictLen, ((sur - 1) / 128 + 1) * 128); - } - if (lenSum + predictLen > limit) { - continue; - } - lenSum += predictLen; - } - } - - generationConfigs.push_back(it.second->generationConfig); - if (it.second->generationConfig.output_logits) { - it.second->resultLogits.push(new std::vector()); - logits.push_back(it.second->resultLogits.back()); - } else { - logits.push_back(nullptr); - } - - tokensManager.units.push_back(it.second->tokens); - handles.push_back(it.first); - - if (it.second->preTokens == 0) { - it.second->intParams["add_special_tokens"] = it.second->cacheLen > 0 ? 
false : it.second->generationConfig.add_special_tokens; - it.second->intParams["promptLen"] = it.second->cacheLen + it.second->currentTokens.size(); - it.second->intParams["index"] = 0; - } else { - it.second->intParams["index"]++; - } - Data inputIds, attentionMask, curPositionIds; - std::vector > tokens; - tokens.resize(1); - for (int i: it.second->currentTokens) { - tokens[0].push_back(i); - } - model->FillLLMInputs(tokens, it.second->intParams, inputIds, attentionMask, curPositionIds); - ToDataType(attentionMask, model->dataType); - - seqLens.push_back(inputIds.Count(0)); - for (int i = 0; i < inputIds.Count(0); i++) { - ids.push_back(((float *) inputIds.cpuData)[i]); - } - if (attentionMask.dims.size() == 0) { - attentionMasks.push_back(nullptr); - } else { - attentionMasks.push_back(new Data()); - attentionMask.ToDevice(DataDevice::CPU); - attentionMasks.back()->CopyFrom(attentionMask); - } - if (curPositionIds.dims.size() == 0) { - positionIds.push_back(nullptr); - } else { - positionIds.push_back(new Data()); - positionIds.back()->CopyFrom(curPositionIds); - } - it.second->preTokens += seqLens.back(); - for (int i = 0; i < model->block_cnt; i++) { - pastKeyValues.push_back(std::make_pair(&it.second->pastKeyValues[i].first, - &it.second->pastKeyValues[i].second)); - } - if (isPrompt) { - cnt += it.second->currentTokens.size(); - - if (cnt > 1024) { - break; - } - // break; - } - - if (seqLens.size() >= maxBatch || lenSum + seqLens.size() * 128 > limit) { - break; - } - } - } - if (seqLens.size() > 0) { - std::vector > *pastKeyValue1; - if (seqLens.size() == 1) { - pastKeyValue1 = &model->responseContextDict.dicts[handles[0]]->pastKeyValues; - } - dictLocker.unlock(); -#ifdef USE_CUDA - FastllmCudaClearBigBuffer(); -#endif - Data inputIds = Data(DataType::FLOAT32, {1, (int) ids.size()}, ids); - std::vector ret; -auto st = std::chrono::system_clock::now(); -//ClearProfiler(); - if (seqLens.size() > 1) { - ret = model->ForwardBatch(seqLens.size(), inputIds, attentionMasks, - positionIds, seqLens, pastKeyValues, generationConfigs, - tokensManager, &logits); - } else { - if (seqLens[0] > 8192) { - int len = seqLens[0]; - int first = 8192, part = 2048; - for (int st = 0; st < len; ) { - int curLen = std::min(st == 0 ? first : part, len - st); - Data curInput, curPositionIds; - Split(inputIds, 1, st, st + curLen, curInput); - Split(*positionIds[0], 1, st, st + curLen, curPositionIds); - - ret = std::vector {model->Forward(curInput, Data(), curPositionIds, - *pastKeyValue1, generationConfigs[0], tokensManager, logits[0])}; - st += curLen; - } - } else { - ret = std::vector {model->Forward(inputIds, - attentionMasks[0] == nullptr ? Data() : *attentionMasks[0], - *positionIds[0], - *pastKeyValue1, generationConfigs[0], tokensManager, logits[0])}; - } - } - - -//PrintProfiler(); -/*int total = 0; -for (int i : seqLens) total += i; -float spend = GetSpan(st, std::chrono::system_clock::now()); -printf("len = %d, spend = %f s. 
tokens / s = %f\n", (int)total, spend, (float)total / spend); -*/ - if (model->verbose) { - genTokens += seqLens.size(); - auto nowTime = std::chrono::system_clock::now(); - float spend = GetSpan(lastRecordTime, nowTime); - if (spend > 1) { - printf("Current batch: %d, Speed: %f tokens / s.\n", (int)seqLens.size(), (float)genTokens / spend); - lastRecordTime = nowTime; - genTokens = 0; - } - } - - dictLocker.lock(); - for (int i = 0; i < handles.size(); i++) { - auto &it = *model->responseContextDict.dicts.find(handles[i]); - int curRet = ret[i]; - if (curRet == model->eos_token_id || model->eos_token_ids.find(curRet) != model->eos_token_ids.end()) { - it.second->isEnding = true; - } else { - auto itStopTk = it.second->generationConfig.stop_token_ids.find(curRet); - if (itStopTk != it.second->generationConfig.stop_token_ids.end()) { - it.second->isEnding = true; - } - } - if (it.second->isEnding == false) { - it.second->currentTokens = std::vector{curRet}; - it.second->resultTokenQueue.push(curRet); - it.second->tokens.Push(curRet); - it.second->curTokens++; - if (it.second->curTokens == it.second->generationConfig.output_token_limit) { - it.second->isEnding = true; - } - } - } - } - - for (int i = 0; i < attentionMasks.size(); i++) { - delete attentionMasks[i]; - } - for (int i = 0; i < positionIds.size(); i++) { - delete positionIds[i]; - } - - if (seqLens.size() == 0) { - model->dictCV.wait(dictLocker); - } - } - }, this); - } - } - mainLoopLocker.unlock(); - - dictLocker.lock(); - int handleId = responseContextDict.CreateHandle(); - ResponseContext *context = responseContextDict.GetHandle(handleId); - context->Init(this->block_cnt, this->dataType); - context->currentTokens = inputTokens; - context->generationConfig = generationConfig; - context->tokens = LastTokensUnit(generationConfig.last_n); - - auto cache = pastKVCacheManager.Get(inputTokens); - if (cache != nullptr) { - for (int i = 0; i < this->block_cnt; i++) { - context->pastKeyValues[i].first.CopyFrom(cache->kv[i].first); - context->pastKeyValues[i].second.CopyFrom(cache->kv[i].second); - } - context->currentTokens.erase(context->currentTokens.begin(), context->currentTokens.begin() + cache->inputToken.size()); - context->cacheLen = cache->inputToken.size(); - } - - dictLocker.unlock(); - dictCV.notify_one(); - return handleId; - } - - bool basellm::CanFetchResponse(int handleId) { - std::unique_lock dictLocker(this->dictLocker); - ResponseContext *context = responseContextDict.GetHandle(handleId); - if (context == nullptr) { - return true; - } else { - return (context->resultTokenQueue.size() > 0 || context->isEnding); - } - } - - int basellm::FetchResponseTokens(int handleId) { - std::unique_lock dictLocker(this->dictLocker); - ResponseContext *context = responseContextDict.GetHandle(handleId); - if (context == nullptr) { - return -1; - } else { - while (true) { - if (context->resultTokenQueue.size() > 0) { - int ret = context->resultTokenQueue.front(); - context->resultTokenQueue.pop(); - return ret; - } else { - if (context->isEnding) { - responseContextDict.RemoveHandle(handleId); - return -1; - } - } - dictLocker.unlock(); - MySleep(0); - dictLocker.lock(); - } - } - } - - int basellm::FetchResponseLogits(int handleId, std::vector &logits) { - std::unique_lock dictLocker(this->dictLocker); - ResponseContext *context = responseContextDict.GetHandle(handleId); - if (context == nullptr) { - return -1; - } else { - while (true) { - if (context->resultTokenQueue.size() > 0) { - int ret = context->resultTokenQueue.front(); - 
context->resultTokenQueue.pop(); - if (!context->resultLogits.empty()) { - logits = *context->resultLogits.front(); - delete context->resultLogits.front(); - context->resultLogits.pop(); - } - return ret; - } else { - if (context->isEnding) { - responseContextDict.RemoveHandle(handleId); - return -1; - } - } - dictLocker.unlock(); - MySleep(0); - dictLocker.lock(); - } - } - } - - void basellm::AddPromptCache(const std::vector &inputTokens) { - std::unique_lock dictLocker(this->dictLocker); - auto cache = pastKVCacheManager.Get(inputTokens); - if (cache != nullptr && cache->inputToken.size() == inputTokens.size()) { - return; - } - Data inputIds, attentionMask, positionIds; - std::vector > pastKeyValues; - for (int i = 0; i < block_cnt; i++) { - pastKeyValues.push_back(std::make_pair(Data(this->dataType), Data(this->dataType))); - pastKeyValues.back().first.SetKVCache(); - pastKeyValues.back().second.SetKVCache(); - } - - int promptLen = inputTokens.size(), index = 0; - int add_special_tokens = false; - std::vector > fInputTokens; - fInputTokens.resize(1); - for (int i = 0; i < inputTokens.size(); i++) { - fInputTokens[0].push_back(inputTokens[i]); - } - FillLLMInputs(fInputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}}, - inputIds, attentionMask, positionIds); - ToDataType(attentionMask, this->dataType); - int ret = Forward(inputIds, attentionMask, positionIds, pastKeyValues); - pastKVCacheManager.Record(inputTokens, inputTokens.size(), &pastKeyValues); - } - - bool basellm::NeedAttentionMask(int qlen, int klen) { - return true; - } - - // 根据输入的tokens生成LLM推理的输入 - void basellm::FillLLMInputs(std::vector > &inputTokens, - const std::map ¶ms, - Data &inputIds, Data &attentionMask, Data &positionIds) { - inputIds.ToDevice(DataDevice::CPU); - attentionMask.ToDevice(DataDevice::CPU); - positionIds.ToDevice(DataDevice::CPU); - - int index = params.find("index")->second; - int promptLen = params.find("promptLen")->second; - - if (inputTokens[0].size() > 1) { - int seqLen = inputTokens[0].size(); - std::vector vpids = std::vector (seqLen, 0); - for (int i = 0; i < seqLen; i++) { - vpids[i] = promptLen - seqLen + i; - } - inputIds.CopyFrom(Data(DataType::FLOAT32, {1, seqLen}, inputTokens[0])); - positionIds.CopyFrom(Data(DataType::FLOAT32, {1, seqLen}, vpids)); - - if (NeedAttentionMask(seqLen, promptLen)) { - std::vector vmask = std::vector (seqLen * promptLen, 0); - for (int i = 0; i < seqLen; i++) { - vpids[i] = promptLen - seqLen + i; - for (int j = i + 1; j < seqLen; j++) { - vmask[i * promptLen + (promptLen - seqLen + j)] = 1; - } - } - attentionMask.CopyFrom(Data(DataType::FLOAT32, {seqLen, promptLen}, vmask)); - } else { - attentionMask = Data(); - } - } else { - inputIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, inputTokens[0])); - attentionMask = Data(); - positionIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, {(float) promptLen + index - 1})); - } - } - - // 根据输入的tokens生成LLM推理的输入 - void basellm::FillLLMInputsBatch(std::vector> &inputTokens, - const std::vector> ¶ms, fastllm::Data &inputIds, - fastllm::Data &attentionMask, fastllm::Data &positionIds) { - inputIds.ToDevice(DataDevice::CPU); - attentionMask.ToDevice(DataDevice::CPU); - positionIds.ToDevice(DataDevice::CPU); - - int batch = inputTokens.size(); - int index = params[0].find("index")->second; - if (index == 0) { - std::vector seqLens; - seqLens.resize(batch); - int maxLen = 0; - for (int i = 0; i < batch; i++) { - maxLen = std::max(maxLen, (int)inputTokens[i].size()); - seqLens[i] 
= (int)inputTokens[i].size(); - } - - std::vector ids = std::vector (batch * maxLen, 0); - std::vector vpids = std::vector (batch * maxLen, 0); - std::vector vmask = std::vector (batch * maxLen * maxLen, 0); - for (int i = 0; i < batch; i++) { - auto &tokens = inputTokens[i]; - int len = tokens.size(), base = maxLen - len; - for (int j = 0; j < len; j++) { - ids[i * maxLen + base + j] = tokens[j]; - } - for (int j = 0; j < len; j++) { - vpids[i * maxLen + base + j] = j; - } - - std::fill(vmask.data() + i * maxLen * maxLen, - vmask.data() + i * maxLen * maxLen + (maxLen - len) * maxLen, 1.0); - for (int j = maxLen - len; j < maxLen; j++) { - std::fill(vmask.data() + i * maxLen * maxLen + j * maxLen, - vmask.data() + i * maxLen * maxLen + j * maxLen + maxLen - len, 1.0); - } - for (int j = 0; j < len; j++) { - for (int k = j + 1; k < len; k++) { - vmask[i * maxLen * maxLen + (base + j) * maxLen + base + k] = 1; - } - } - } - - inputIds.CopyFrom(Data(DataType::FLOAT32, {batch, maxLen}, ids)); - attentionMask.CopyFrom(Data(DataType::FLOAT32, {batch, maxLen, maxLen}, vmask)); - positionIds.CopyFrom(Data(DataType::FLOAT32, {batch, maxLen}, vpids)); - } else { - std::vector pids = std::vector (batch); - std::vector fret; - for (int i = 0; i < batch; i++) { - fret.push_back(inputTokens[i][0]); - } - int maxLen = 0; - for (int i = 0; i < batch; i++) { - int promptLen = params[i].find("promptLen")->second; - maxLen = std::max(promptLen, maxLen); - pids[i] = promptLen + index - 1; - } - maxLen += index; - std::vector vmasks = std::vector (batch * maxLen, 0.0f); - for (int i = 0; i < batch; i++) { - int curLen = params[i].find("promptLen")->second + index; - for (int j = 0; j < maxLen - curLen; j++) { - vmasks[i * maxLen + j] = 1.0f; - } - } - - inputIds.CopyFrom(Data(DataType::FLOAT32, {batch, 1}, fret)); - attentionMask.CopyFrom(Data(DataType::FLOAT32, {batch, 1, maxLen}, vmasks)); - positionIds.CopyFrom(Data(DataType::FLOAT32, {batch, 1}, pids)); - } - } - - void basellm::SetAdapter(const std::string &name) { - if (weight.peftDict.find(name) == weight.peftDict.end()) { - ErrorInFastLLM("Can`t find adapter name: " + name); - } - adapterName = name; - } - - void basellm::DisableAdapter() { - adapterName = ""; - } - - bool basellm::SetSaveHistoryChat(bool save) { - if (this->model_type == "llama" || - this->model_type == "moe" || - this->model_type == "internlm" || - this->model_type == "qwen2_moe" || - this->model_type == "deepseek_v2" || - this->model_type == "qwen") { - this->saveHistoryChat = save; - return true; - } - return false; - } - - void basellm::SetDataType(DataType dataType) { - if (dataType == DataType::FLOAT32) { - - } else if (dataType == DataType::FLOAT16) { - AssertInFastLLM(this->model_struct == "chatglm" || - this->model_struct == "llama" || - this->model_struct == "graph", - this->model_struct + " doesn't support float16"); - } else { - ErrorInFastLLM("SetDataType Error: datatype should be float32 or float16"); - } - this->dataType = dataType; - } - - JinjaVar ChatMessagesToJinjaVar(const ChatMessages &messages) { - JinjaVar ret = {{"messages", fastllm::JinjaArray {}}}; - for (auto &message : messages) { - ret["messages"].arrayValue.push_back({ - {"role", message.first}, - {"content", message.second} - }); - } - ret["add_generation_prompt"] = fastllm::JinjaVar{1}; - return ret; - } - - std::string basellm::ApplyChatTemplate(const ChatMessages &messages) { - if (this->weight.tokenizer.chatTemplate == "") { - std::string ret = ""; - std::string user = ""; - int round = 0; - for 
(auto &message : messages) { - if (message.first == "user") { - user = message.second; - } else if (message.first == "assistant") { - ret = MakeHistory(ret, round++, user, message.second); - } - } - ret = MakeInput(ret, round, user); - return ret; - } - return ApplyChatTemplate(ChatMessagesToJinjaVar(messages)); - } - - std::vector basellm::ApplyChatTemplateToTokens(const ChatMessages &messages) { - auto prompt = this->ApplyChatTemplate(messages); - auto input = this->weight.tokenizer.Encode(prompt); - std::vector tokens; - for (int i = 0; i < input.Count(0); i++) { - tokens.push_back(((float *) input.cpuData)[i]); - } - return tokens; - } - - std::string basellm::ApplyChatTemplate(const JinjaVar &var) { - AssertInFastLLM(this->weight.tokenizer.chatTemplate != "", - "ApplyChatTemplate error: model doesn't has chat_template."); - JinjaVar local = var; - for (auto &it : this->weight.tokenizer.tokenizerConfig.object_items()) { - if (it.first != "messages" && it.second.is_string()) { - local[it.first] = it.second.string_value(); - } - } - JinjaTemplate temp = JinjaTemplate(this->weight.tokenizer.chatTemplate); - return temp.Apply(local); - } - - std::vector basellm::ApplyChatTemplateToTokens(const JinjaVar &var) { - auto prompt = this->ApplyChatTemplate(var); - auto input = this->weight.tokenizer.Encode(prompt); - std::vector tokens; - for (int i = 0; i < input.Count(0); i++) { - tokens.push_back(((float *) input.cpuData)[i]); - } - return tokens; - } -} +// +// Created by huangyuyang on 6/25/23. +// + +#include "basellm.h" +#include "utils.h" +#include +#include + +#ifdef USE_CUDA +#include "fastllm-cuda.cuh" +#endif + +namespace fastllm { + int ResponseContextDict::CreateHandle() { + locker.lock(); + int newId = 0; + while (dicts.find(newId) != dicts.end()) { + newId++; + } + dicts[newId] = new ResponseContext(); + locker.unlock(); + return newId; + } + + ResponseContext *ResponseContextDict::GetHandle(int handleId) { + locker.lock(); + ResponseContext *ret = dicts.find(handleId) != dicts.end() ? 
dicts[handleId] : nullptr; + locker.unlock(); + return ret; + } + + void ResponseContextDict::RemoveHandle(int handleId) { + locker.lock(); + if (dicts.find(handleId) != dicts.end()) { + delete dicts[handleId]; + dicts.erase(handleId); + } + locker.unlock(); + } + + void ResponseContext::Init(int blocks, DataType dataType) { + pastKeyValues.clear(); + for (int i = 0; i < blocks; i++) { + pastKeyValues.push_back(std::make_pair(Data(dataType), + Data(dataType))); + pastKeyValues.back().first.SetKVCache(); + pastKeyValues.back().second.SetKVCache(); + } + intParams.clear(); + currentTokens.clear(); + while (resultTokenQueue.size() > 0){ + resultTokenQueue.pop(); + } + isEnding = false; + preTokens = 0; + } + + PastKVCacheMemory::PastKVCacheMemory(const std::vector &inputToken, int tokens, long long flushTime, std::vector > *kv) { + this->inputToken = inputToken; + this->tokens = tokens; + this->flushTime = flushTime; + this->recordTimes = 1; + auto dataType = (*kv)[0].first.dataType; + for (int i = 0; i < kv->size(); i++) { + this->kv.push_back(std::make_pair(Data(dataType), Data(dataType))); + } + for (int i = 0; i < kv->size(); i++) { + this->kv[i].first.CopyFrom((*kv)[i].first); + this->kv[i].second.CopyFrom((*kv)[i].second); + } + } + + void PastKVCacheManager::SetMaxRecordNum(int maxRecordNum) { + std::lock_guard lock(this->locker); + this->maxRecordNum = maxRecordNum; + } + + void PastKVCacheManager::Record(const std::vector &inputToken, int tokens, std::vector > *kv) { + std::lock_guard lock(this->locker); + if (this->memorys.find(inputToken) != this->memorys.end()) { + this->memorys[inputToken]->recordTimes++; + this->memorys[inputToken]->flushTime = ++flushTime; + return; + } + + if (this->memorys.size() >= this->maxRecordNum) { + std::vector eraseToken; + long long minFlushTime = (1LL << 60); + for (auto &it : this->memorys) { + if (it.second->flushTime < minFlushTime) { + minFlushTime = it.second->flushTime; + eraseToken = it.first; + } + } + delete this->memorys[eraseToken]; + this->memorys.erase(this->memorys.find(eraseToken)); + } + + this->memorys[inputToken] = new PastKVCacheMemory(inputToken, tokens, ++flushTime, kv); + } + + void PastKVCacheManager::Remove(const std::vector &inputToken) { + std::lock_guard lock(this->locker); + if (this->memorys.find(inputToken) != this->memorys.end()) { + if ((--this->memorys[inputToken]->recordTimes) <= 0) { + delete this->memorys[inputToken]; + this->memorys.erase(this->memorys.find(inputToken)); + } + } + } + + PastKVCacheMemory *PastKVCacheManager::Get(const std::vector &inputToken) { + std::lock_guard lock(this->locker); + std::vector maxToken; + for (auto &it : this->memorys) { + const std::vector &cur = it.first; + if (cur.size() > maxToken.size() && cur.size() <= inputToken.size()) { + bool match = true; + for (int i = 0; i < cur.size(); i++) { + if (inputToken[i] != cur[i]) { + match = false; + break; + } + } + if (match) { + maxToken = cur; + } + } + } + if (maxToken.size() == 0) { + return nullptr; + } + this->memorys[maxToken]->flushTime = ++this->flushTime; + return this->memorys[maxToken]; + } + + void PastKVCacheManager::Unlock() { + locker.unlock(); + } + + basellm::~basellm() { + dictLocker.lock(); + this->isFree = true; + dictLocker.unlock(); + dictCV.notify_all(); + this->weight.ReleaseWeight(); + } + + std::map > > + basellm::GetTensorMap(const std::vector &tensorNames) { + std::map > > ret; + for (auto &name : tensorNames) { + WeightType weightType = this->weight.GetWeightType(name); + DataType dataType = 
DataType::DATA_AUTO_NONE; + if (weightType == WeightType::LINEAR) { + dataType = DataType::DATA_AUTO_LINEAR; + } else if (weightType == WeightType::EMBEDDING) { + dataType = DataType::DATA_AUTO_EMBEDDING; + } + ret[name].push_back(std::make_pair(name, dataType)); + } + return ret; + } + + std::string basellm::Response(const std::string &oriInput, RuntimeResult retCb, + const fastllm::GenerationConfig &generationConfig) { + std::string input = oriInput; + if (this->saveHistoryChat) { + if (lastKeyValues != nullptr) { + if (input.size() < lastPrompt.size() || (input.substr(0, lastPrompt.size()) != lastPrompt)) { + lastPrompt = ""; + lastPromptTokens = 0; + delete lastKeyValues; + lastKeyValues = nullptr; + } else { + input = input.substr(lastPrompt.size()); + } + } + } else { + lastPrompt = ""; + lastPromptTokens = 0; + delete lastKeyValues; + lastKeyValues = nullptr; + } + + //printf("lastPrompt = %s\n", lastPrompt.c_str()); + //printf("input = %s\n", input.c_str()); + +#ifdef USE_CUDA + FastllmCudaClearBigBuffer(); +#endif + std::string prompt = input; +#ifdef PY_API + size_t pos = input.rfind("time_stamp:"); + prompt = (generationConfig.enable_hash_id && pos != -1) ? input.substr(0, pos) : input; + size_t hash_id = std::hash{}(input); +#endif + Data inputIds, attentionMask, positionIds; + + Data inputTokenData = this->weight.tokenizer.Encode(prompt); + std::vector > inputTokens; + inputTokens.resize(1); + for (int i = 0; i < inputTokenData.Count(0); i++) { + inputTokens[0].push_back(((float *) inputTokenData.cpuData)[i]); + } + + if (lastKeyValues == nullptr) { + lastKeyValues = new std::vector >(); + for (int i = 0; i < block_cnt; i++) { + lastKeyValues->push_back(std::make_pair(Data(this->dataType), Data(this->dataType))); + lastKeyValues->back().first.SetKVCache(); + lastKeyValues->back().second.SetKVCache(); + } + } + + std::vector > &pastKeyValues = (*lastKeyValues); + std::string retString = ""; + std::vector results; + LastTokensManager tokens(1, generationConfig.last_n); + int promptLen = lastPromptTokens + inputTokens[0].size(), index = 0; + int add_special_tokens = generationConfig.add_special_tokens? 
1: 0; + FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}}, + inputIds, attentionMask, positionIds); + ToDataType(attentionMask, this->dataType); + while (true) { + auto st = std::chrono::system_clock::now(); + int ret = Forward(inputIds, attentionMask, positionIds, pastKeyValues, generationConfig, tokens); + tokens.units[0].Push(ret); + if (ret == eos_token_id + || generationConfig.stop_token_ids.find(ret) != generationConfig.stop_token_ids.end() + || eos_token_ids.find(ret) != eos_token_ids.end()) { + break; + } + + results.push_back(ret); + std::string curString = weight.tokenizer.Decode( + Data(DataType::FLOAT32, {(int) results.size()}, results)).c_str(); + retString += curString; + if (retCb) +#ifdef PY_API + { + if (generationConfig.enable_hash_id) { + std::stringstream ss; + ss << retString << "hash_id:"< {(float)ret}; + FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}}, + inputIds, attentionMask, positionIds); + ToDataType(attentionMask, this->dataType); + if (index == generationConfig.output_token_limit) { + break; + } + // printf("len = %d, spend %f s.\n", len, GetSpan(st, std::chrono::system_clock::now())); + } + if (retCb) +#ifdef PY_API + { + if(generationConfig.enable_hash_id){ + std::stringstream ss; + ss << retString << "hash_id:"< &inputs, std::vector &outputs, + RuntimeResultBatch retCb, const fastllm::GenerationConfig &generationConfig) { +#ifdef USE_CUDA + FastllmCudaClearBigBuffer(); +#endif +#ifdef PY_API + std::vector prompts; + std::vector < size_t > hash_ids; + for (auto _input: inputs){ + size_t hash_id = std::hash{}(_input); + hash_ids.push_back(hash_id); + + size_t pos = _input.rfind("time_stamp:"); + std::string prompt = (generationConfig.enable_hash_id && pos != -1) ? _input.substr(0, pos) : _input; + prompts.push_back(prompt); + } +#else + std::vector prompts = inputs; +#endif + // 1. first + Data inputIds, attentionMask, positionIds; + + int batch = prompts.size(); + outputs.clear(); + outputs.resize(batch, ""); + + std::vector > inputTokens; + inputTokens.resize(batch); + + for (int i = 0; i < batch; i++) { + Data now = this->weight.tokenizer.Encode(prompts[i]); + for (int j = 0; j < now.Count(0); j++) { + inputTokens[i].push_back(((float *) now.cpuData)[j]); + } + } + + std::vector > pastKeyValues; + for (int i = 0; i < block_cnt; i++) { + pastKeyValues.push_back(std::make_pair(Data(dataType), + Data(dataType))); + pastKeyValues.back().first.SetKVCache(); + pastKeyValues.back().second.SetKVCache(); + } + + std::vector > params; + params.resize(batch); + for (int i = 0; i < batch; i++) { + params[i]["promptLen"] = (int)inputTokens[i].size(); + } + params[0]["index"] = 0; + int index = 0; + params[0]["add_special_tokens"] = generationConfig.add_special_tokens? 
1: 0; + + LastTokensManager tokensManager (batch, generationConfig.last_n); + std::vector isEnding = std::vector (batch, false); + FillLLMInputsBatch(inputTokens, params, inputIds, attentionMask, positionIds); + ToDataType(attentionMask, this->dataType); + while (true) { + auto st = std::chrono::system_clock::now(); +// ClearProfiler(); + std::vector ret = ForwardBatch(batch, inputIds, attentionMask, positionIds, pastKeyValues, + generationConfig, tokensManager); +// PrintProfiler(); + for (int i = 0; i < batch; i++) { + tokensManager.units[i].Push(ret[i]); + } + std::vector fret; + std::vector results; + int endingCount = 0; + std::vector curStrings; + for (int i = 0; i < batch; i++) { + fret.push_back(ret[i]); + inputTokens[i] = std::vector {(float)ret[i]}; + if (ret[i] == eos_token_id || eos_token_ids.find(ret[i]) != eos_token_ids.end()) { + isEnding[i] = true; + } else { + auto itStopTk = generationConfig.stop_token_ids.find(ret[i]); + if (itStopTk != generationConfig.stop_token_ids.end()) { + isEnding[i] = true; + } + } + if (isEnding[i]) { + curStrings.push_back(""); + endingCount++; + continue; + } + results.push_back(ret[i]); + std::string curString = weight.tokenizer.Decode( + Data(DataType::FLOAT32, {(int) results.size()}, results)).c_str(); + outputs[i] += curString; + curStrings.push_back(curString); + results.clear(); + } + + if (endingCount == batch) { + break; + } + if (retCb) +#ifdef PY_API + { + if (generationConfig.enable_hash_id) { + std::vector rtnStrings; + for (size_t i=0; i rtnStrings; + for (size_t i=0; idataType); + // printf("len = %d, spend %f s.\n", len, GetSpan(st, std::chrono::system_clock::now())); + + if (index == generationConfig.output_token_limit) { + break; + } + } + if (retCb) +#ifdef PY_API + { + if (generationConfig.enable_hash_id) { + std::vector rtnStrings; + for (size_t i=0; i rtnStrings; + for (size_t i=0; i basellm::ForwardBatch(int batch, const fastllm::Data &inputIds, const fastllm::Data &attentionMask, + const fastllm::Data &positionIds, + std::vector> &pastKeyValues, + const fastllm::GenerationConfig &generationConfig, + const fastllm::LastTokensManager &lastTokens, + std::vector *> *retLogits) { + printf("Unsupport forward batch.\n"); + exit(0); + } + + std::vector basellm::ForwardBatch(int batch, const fastllm::Data &inputIds, + const std::vector &attentionMask, + const std::vector &positionIds, const std::vector &seqLens, + std::vector> &pastKeyValues, + const std::vector &generationConfigs, + const fastllm::LastTokensManager &lastTokens, + std::vector *> *logits) { + std::vector ret; + int cur = 0; + for (int i = 0; i < batch; i++) { + std::vector > curKV; + curKV.resize(this->block_cnt); + for (int j = 0; j < this->block_cnt; j++) { + Mul(*pastKeyValues[i * this->block_cnt + j].first, 1.0, curKV[j].first); + Mul(*pastKeyValues[i * this->block_cnt + j].second, 1.0, curKV[j].second); + } + Data curInput; + Split(inputIds, 1, cur, cur + seqLens[i], curInput); + LastTokensManager curTokens; + curTokens.units.push_back(lastTokens.units[i]); + ret.push_back(this->Forward(curInput, *attentionMask[i], *positionIds[i], curKV, generationConfigs[i], curTokens)); + for (int j = 0; j < this->block_cnt; j++) { + Mul(curKV[j].first, 1.0, *pastKeyValues[i * this->block_cnt + j].first); + Mul(curKV[j].second, 1.0, *pastKeyValues[i * this->block_cnt + j].second); + } + } + return ret; + } + + int basellm::LaunchResponseTokens(const std::vector &inputTokens, + const fastllm::GenerationConfig &generationConfig) { + mainLoopLocker.lock(); + if (mainLoop == 
nullptr) { + if (mainLoop == nullptr) { + mainLoop = new std::thread([](basellm *model) { + long long kvCacheLimit = 16LL << 30; +#ifdef USE_CUDA + auto freeSizes = FastllmCudaGetFreeSizes(); + kvCacheLimit = 0; + for (long long i : freeSizes) { + kvCacheLimit += std::max(0LL, i - (2LL << 30)); + } +#endif + if (model->kvCacheLimit > 0) { + kvCacheLimit = model->kvCacheLimit; + } + + int unitSize = (model->dataType == DataType::FLOAT32 ? 4 : 2); + int maxTotalLens = kvCacheLimit / (model->elementsInKVCachePerToken * unitSize); + if (model->elementsInKVCachePerToken <= 0) { + maxTotalLens = kvCacheLimit / 1024 / 1024; + } + if (model->tokensLimit > 0) { + maxTotalLens = model->tokensLimit; + } + + int maxBatch = std::max(1, std::min(512, maxTotalLens / 128)); + if (model->maxBatch > 0) { + maxBatch = model->maxBatch; + } + + model->tokensLimit = maxTotalLens; + if (model->verbose) { + printf("Fastllm KV Cache Limit: %f MB.\n", (double)kvCacheLimit / 1024 / 1024); + printf("Fastllm KV Cache Token limit: %d tokens.\n", maxTotalLens); + printf("Fastllm Batch limit: %d.\n", maxBatch); + } + + auto lastRecordTime = std::chrono::system_clock::now(); + long long genTokens = 0; + while (true) { + if (model->isFree) { + break; + } + std::vector attentionMasks; + std::vector positionIds; + std::vector > pastKeyValues; + std::vector ids; + std::vector seqLens; + std::vector handles; + std::vector generationConfigs; + LastTokensManager tokensManager; + std::vector * > logits; + + std::unique_lock dictLocker(model->dictLocker); + + int limit = maxTotalLens; + int promptLimit = limit * 2 / 3; + + int lenSum = 0; + for (auto &it: model->responseContextDict.dicts) { + if (it.second->pastKeyValues[0].first.expansionDims.size() > 0) { + lenSum += it.second->pastKeyValues[0].first.expansionDims[1]; + } + } + + for (int isPrompt = 1; isPrompt >= 0; isPrompt--) { + int cnt = 0; + if (isPrompt == 0 && seqLens.size() > 0) { + continue; + } + if (lenSum >= promptLimit && isPrompt) { + continue; + } + + for (auto &it: model->responseContextDict.dicts) { + if (it.second->isEnding) { + continue; + } + if (isPrompt && it.second->preTokens != 0) { + continue; + } + if (!isPrompt && it.second->preTokens == 0) { + continue; + } + + int outputLimit = it.second->generationConfig.output_token_limit; + outputLimit = (outputLimit < 0 ? 128 : outputLimit); + if (isPrompt && lenSum + it.second->currentTokens.size() > promptLimit) { + continue; + } + + if (!isPrompt) { + if (it.second->pastKeyValues[0].first.expansionDims[1] == it.second->pastKeyValues[0].first.dims[1]) { + int sur = it.second->generationConfig.output_token_limit - it.second->curTokens; + int predictLen = 256; + if (sur > 0) { + predictLen = std::min(predictLen, ((sur - 1) / 128 + 1) * 128); + } + if (lenSum + predictLen > limit) { + continue; + } + lenSum += predictLen; + } + } + + generationConfigs.push_back(it.second->generationConfig); + if (it.second->generationConfig.output_logits) { + it.second->resultLogits.push(new std::vector()); + logits.push_back(it.second->resultLogits.back()); + } else { + logits.push_back(nullptr); + } + + tokensManager.units.push_back(it.second->tokens); + handles.push_back(it.first); + + if (it.second->preTokens == 0) { + it.second->intParams["add_special_tokens"] = it.second->cacheLen > 0 ? 
false : it.second->generationConfig.add_special_tokens; + it.second->intParams["promptLen"] = it.second->cacheLen + it.second->currentTokens.size(); + it.second->intParams["index"] = 0; + } else { + it.second->intParams["index"]++; + } + Data inputIds, attentionMask, curPositionIds; + std::vector > tokens; + tokens.resize(1); + for (int i: it.second->currentTokens) { + tokens[0].push_back(i); + } + model->FillLLMInputs(tokens, it.second->intParams, inputIds, attentionMask, curPositionIds); + ToDataType(attentionMask, model->dataType); + + seqLens.push_back(inputIds.Count(0)); + for (int i = 0; i < inputIds.Count(0); i++) { + ids.push_back(((float *) inputIds.cpuData)[i]); + } + if (attentionMask.dims.size() == 0) { + attentionMasks.push_back(nullptr); + } else { + attentionMasks.push_back(new Data()); + attentionMask.ToDevice(DataDevice::CPU); + attentionMasks.back()->CopyFrom(attentionMask); + } + if (curPositionIds.dims.size() == 0) { + positionIds.push_back(nullptr); + } else { + positionIds.push_back(new Data()); + positionIds.back()->CopyFrom(curPositionIds); + } + it.second->preTokens += seqLens.back(); + for (int i = 0; i < model->block_cnt; i++) { + pastKeyValues.push_back(std::make_pair(&it.second->pastKeyValues[i].first, + &it.second->pastKeyValues[i].second)); + } + if (isPrompt) { + cnt += it.second->currentTokens.size(); + + if (cnt > 1024) { + break; + } + // break; + } + + if (seqLens.size() >= maxBatch || lenSum + seqLens.size() * 128 > limit) { + break; + } + } + } + if (seqLens.size() > 0) { + std::vector > *pastKeyValue1; + if (seqLens.size() == 1) { + pastKeyValue1 = &model->responseContextDict.dicts[handles[0]]->pastKeyValues; + } + dictLocker.unlock(); +#ifdef USE_CUDA + FastllmCudaClearBigBuffer(); +#endif + Data inputIds = Data(DataType::FLOAT32, {1, (int) ids.size()}, ids); + std::vector ret; +auto st = std::chrono::system_clock::now(); +//ClearProfiler(); + if (seqLens.size() > 1) { + ret = model->ForwardBatch(seqLens.size(), inputIds, attentionMasks, + positionIds, seqLens, pastKeyValues, generationConfigs, + tokensManager, &logits); + } else { + if (seqLens[0] > 8192) { + int len = seqLens[0]; + int first = 8192, part = 2048; + for (int st = 0; st < len; ) { + int curLen = std::min(st == 0 ? first : part, len - st); + Data curInput, curPositionIds; + Split(inputIds, 1, st, st + curLen, curInput); + Split(*positionIds[0], 1, st, st + curLen, curPositionIds); + + ret = std::vector {model->Forward(curInput, Data(), curPositionIds, + *pastKeyValue1, generationConfigs[0], tokensManager, logits[0])}; + st += curLen; + } + } else { + ret = std::vector {model->Forward(inputIds, + attentionMasks[0] == nullptr ? Data() : *attentionMasks[0], + *positionIds[0], + *pastKeyValue1, generationConfigs[0], tokensManager, logits[0])}; + } + } + + +//PrintProfiler(); +/*int total = 0; +for (int i : seqLens) total += i; +float spend = GetSpan(st, std::chrono::system_clock::now()); +printf("len = %d, spend = %f s. 
tokens / s = %f\n", (int)total, spend, (float)total / spend); +*/ + if (model->verbose) { + genTokens += seqLens.size(); + auto nowTime = std::chrono::system_clock::now(); + float spend = GetSpan(lastRecordTime, nowTime); + if (spend > 1) { + printf("Current batch: %d, Speed: %f tokens / s.\n", (int)seqLens.size(), (float)genTokens / spend); + lastRecordTime = nowTime; + genTokens = 0; + } + } + + dictLocker.lock(); + for (int i = 0; i < handles.size(); i++) { + auto &it = *model->responseContextDict.dicts.find(handles[i]); + int curRet = ret[i]; + if (curRet == model->eos_token_id || model->eos_token_ids.find(curRet) != model->eos_token_ids.end()) { + it.second->isEnding = true; + } else { + auto itStopTk = it.second->generationConfig.stop_token_ids.find(curRet); + if (itStopTk != it.second->generationConfig.stop_token_ids.end()) { + it.second->isEnding = true; + } + } + if (it.second->isEnding == false) { + it.second->currentTokens = std::vector{curRet}; + it.second->resultTokenQueue.push(curRet); + it.second->tokens.Push(curRet); + it.second->curTokens++; + if (it.second->curTokens == it.second->generationConfig.output_token_limit) { + it.second->isEnding = true; + } + } + } + } + + for (int i = 0; i < attentionMasks.size(); i++) { + delete attentionMasks[i]; + } + for (int i = 0; i < positionIds.size(); i++) { + delete positionIds[i]; + } + + if (seqLens.size() == 0) { + model->dictCV.wait(dictLocker); + } + } + }, this); + } + } + mainLoopLocker.unlock(); + + dictLocker.lock(); + int handleId = responseContextDict.CreateHandle(); + ResponseContext *context = responseContextDict.GetHandle(handleId); + context->Init(this->block_cnt, this->dataType); + context->currentTokens = inputTokens; + context->generationConfig = generationConfig; + context->tokens = LastTokensUnit(generationConfig.last_n); + + auto cache = pastKVCacheManager.Get(inputTokens); + if (cache != nullptr) { + for (int i = 0; i < this->block_cnt; i++) { + context->pastKeyValues[i].first.CopyFrom(cache->kv[i].first); + context->pastKeyValues[i].second.CopyFrom(cache->kv[i].second); + } + context->currentTokens.erase(context->currentTokens.begin(), context->currentTokens.begin() + cache->inputToken.size()); + context->cacheLen = cache->inputToken.size(); + } + + dictLocker.unlock(); + dictCV.notify_one(); + return handleId; + } + + bool basellm::CanFetchResponse(int handleId) { + std::unique_lock dictLocker(this->dictLocker); + ResponseContext *context = responseContextDict.GetHandle(handleId); + if (context == nullptr) { + return true; + } else { + return (context->resultTokenQueue.size() > 0 || context->isEnding); + } + } + + int basellm::FetchResponseTokens(int handleId) { + std::unique_lock dictLocker(this->dictLocker); + ResponseContext *context = responseContextDict.GetHandle(handleId); + if (context == nullptr) { + return -1; + } else { + while (true) { + if (context->resultTokenQueue.size() > 0) { + int ret = context->resultTokenQueue.front(); + context->resultTokenQueue.pop(); + return ret; + } else { + if (context->isEnding) { + responseContextDict.RemoveHandle(handleId); + return -1; + } + } + dictLocker.unlock(); + MySleep(0); + dictLocker.lock(); + } + } + } + + int basellm::FetchResponseLogits(int handleId, std::vector &logits) { + std::unique_lock dictLocker(this->dictLocker); + ResponseContext *context = responseContextDict.GetHandle(handleId); + if (context == nullptr) { + return -1; + } else { + while (true) { + if (context->resultTokenQueue.size() > 0) { + int ret = context->resultTokenQueue.front(); + 
context->resultTokenQueue.pop(); + if (!context->resultLogits.empty()) { + logits = *context->resultLogits.front(); + delete context->resultLogits.front(); + context->resultLogits.pop(); + } + return ret; + } else { + if (context->isEnding) { + responseContextDict.RemoveHandle(handleId); + return -1; + } + } + dictLocker.unlock(); + MySleep(0); + dictLocker.lock(); + } + } + } + + void basellm::AddPromptCache(const std::vector &inputTokens) { + std::unique_lock dictLocker(this->dictLocker); + auto cache = pastKVCacheManager.Get(inputTokens); + if (cache != nullptr && cache->inputToken.size() == inputTokens.size()) { + return; + } + Data inputIds, attentionMask, positionIds; + std::vector > pastKeyValues; + for (int i = 0; i < block_cnt; i++) { + pastKeyValues.push_back(std::make_pair(Data(this->dataType), Data(this->dataType))); + pastKeyValues.back().first.SetKVCache(); + pastKeyValues.back().second.SetKVCache(); + } + + int promptLen = inputTokens.size(), index = 0; + int add_special_tokens = false; + std::vector > fInputTokens; + fInputTokens.resize(1); + for (int i = 0; i < inputTokens.size(); i++) { + fInputTokens[0].push_back(inputTokens[i]); + } + FillLLMInputs(fInputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}}, + inputIds, attentionMask, positionIds); + ToDataType(attentionMask, this->dataType); + int ret = Forward(inputIds, attentionMask, positionIds, pastKeyValues); + pastKVCacheManager.Record(inputTokens, inputTokens.size(), &pastKeyValues); + } + + bool basellm::NeedAttentionMask(int qlen, int klen) { + return true; + } + + // 根据输入的tokens生成LLM推理的输入 + void basellm::FillLLMInputs(std::vector > &inputTokens, + const std::map ¶ms, + Data &inputIds, Data &attentionMask, Data &positionIds) { + inputIds.ToDevice(DataDevice::CPU); + attentionMask.ToDevice(DataDevice::CPU); + positionIds.ToDevice(DataDevice::CPU); + + int index = params.find("index")->second; + int promptLen = params.find("promptLen")->second; + + if (inputTokens[0].size() > 1) { + int seqLen = inputTokens[0].size(); + std::vector vpids = std::vector (seqLen, 0); + for (int i = 0; i < seqLen; i++) { + vpids[i] = promptLen - seqLen + i; + } + inputIds.CopyFrom(Data(DataType::FLOAT32, {1, seqLen}, inputTokens[0])); + positionIds.CopyFrom(Data(DataType::FLOAT32, {1, seqLen}, vpids)); + + if (NeedAttentionMask(seqLen, promptLen)) { + std::vector vmask = std::vector (seqLen * promptLen, 0); + for (int i = 0; i < seqLen; i++) { + vpids[i] = promptLen - seqLen + i; + for (int j = i + 1; j < seqLen; j++) { + vmask[i * promptLen + (promptLen - seqLen + j)] = 1; + } + } + attentionMask.CopyFrom(Data(DataType::FLOAT32, {seqLen, promptLen}, vmask)); + } else { + attentionMask = Data(); + } + } else { + inputIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, inputTokens[0])); + attentionMask = Data(); + positionIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, {(float) promptLen + index - 1})); + } + } + + // 根据输入的tokens生成LLM推理的输入 + void basellm::FillLLMInputsBatch(std::vector> &inputTokens, + const std::vector> ¶ms, fastllm::Data &inputIds, + fastllm::Data &attentionMask, fastllm::Data &positionIds) { + inputIds.ToDevice(DataDevice::CPU); + attentionMask.ToDevice(DataDevice::CPU); + positionIds.ToDevice(DataDevice::CPU); + + int batch = inputTokens.size(); + int index = params[0].find("index")->second; + if (index == 0) { + std::vector seqLens; + seqLens.resize(batch); + int maxLen = 0; + for (int i = 0; i < batch; i++) { + maxLen = std::max(maxLen, (int)inputTokens[i].size()); + seqLens[i] 
= (int)inputTokens[i].size(); + } + + std::vector ids = std::vector (batch * maxLen, 0); + std::vector vpids = std::vector (batch * maxLen, 0); + std::vector vmask = std::vector (batch * maxLen * maxLen, 0); + for (int i = 0; i < batch; i++) { + auto &tokens = inputTokens[i]; + int len = tokens.size(), base = maxLen - len; + for (int j = 0; j < len; j++) { + ids[i * maxLen + base + j] = tokens[j]; + } + for (int j = 0; j < len; j++) { + vpids[i * maxLen + base + j] = j; + } + + std::fill(vmask.data() + i * maxLen * maxLen, + vmask.data() + i * maxLen * maxLen + (maxLen - len) * maxLen, 1.0); + for (int j = maxLen - len; j < maxLen; j++) { + std::fill(vmask.data() + i * maxLen * maxLen + j * maxLen, + vmask.data() + i * maxLen * maxLen + j * maxLen + maxLen - len, 1.0); + } + for (int j = 0; j < len; j++) { + for (int k = j + 1; k < len; k++) { + vmask[i * maxLen * maxLen + (base + j) * maxLen + base + k] = 1; + } + } + } + + inputIds.CopyFrom(Data(DataType::FLOAT32, {batch, maxLen}, ids)); + attentionMask.CopyFrom(Data(DataType::FLOAT32, {batch, maxLen, maxLen}, vmask)); + positionIds.CopyFrom(Data(DataType::FLOAT32, {batch, maxLen}, vpids)); + } else { + std::vector pids = std::vector (batch); + std::vector fret; + for (int i = 0; i < batch; i++) { + fret.push_back(inputTokens[i][0]); + } + int maxLen = 0; + for (int i = 0; i < batch; i++) { + int promptLen = params[i].find("promptLen")->second; + maxLen = std::max(promptLen, maxLen); + pids[i] = promptLen + index - 1; + } + maxLen += index; + std::vector vmasks = std::vector (batch * maxLen, 0.0f); + for (int i = 0; i < batch; i++) { + int curLen = params[i].find("promptLen")->second + index; + for (int j = 0; j < maxLen - curLen; j++) { + vmasks[i * maxLen + j] = 1.0f; + } + } + + inputIds.CopyFrom(Data(DataType::FLOAT32, {batch, 1}, fret)); + attentionMask.CopyFrom(Data(DataType::FLOAT32, {batch, 1, maxLen}, vmasks)); + positionIds.CopyFrom(Data(DataType::FLOAT32, {batch, 1}, pids)); + } + } + + void basellm::SetAdapter(const std::string &name) { + if (weight.peftDict.find(name) == weight.peftDict.end()) { + ErrorInFastLLM("Can`t find adapter name: " + name); + } + adapterName = name; + } + + void basellm::DisableAdapter() { + adapterName = ""; + } + + bool basellm::SetSaveHistoryChat(bool save) { + if (this->model_type == "llama" || + this->model_type == "moe" || + this->model_type == "internlm" || + this->model_type == "qwen2_moe" || + this->model_type == "deepseek_v2" || + this->model_type == "qwen") { + this->saveHistoryChat = save; + return true; + } + return false; + } + + void basellm::SetDataType(DataType dataType) { + if (dataType == DataType::FLOAT32) { + + } else if (dataType == DataType::FLOAT16) { + AssertInFastLLM(this->model_struct == "chatglm" || + this->model_struct == "llama" || + this->model_struct == "graph", + this->model_struct + " doesn't support float16"); + } else { + ErrorInFastLLM("SetDataType Error: datatype should be float32 or float16"); + } + this->dataType = dataType; + } + + JinjaVar ChatMessagesToJinjaVar(const ChatMessages &messages) { + JinjaVar ret = {{"messages", fastllm::JinjaArray {}}}; + for (auto &message : messages) { + ret["messages"].arrayValue.push_back({ + {"role", message.first}, + {"content", message.second} + }); + } + ret["add_generation_prompt"] = fastllm::JinjaVar{1}; + return ret; + } + + std::string basellm::ApplyChatTemplate(const ChatMessages &messages) { + if (this->weight.tokenizer.chatTemplate == "") { + std::string ret = ""; + std::string user = ""; + int round = 0; + for 
(auto &message : messages) { + if (message.first == "user") { + user = message.second; + } else if (message.first == "assistant") { + ret = MakeHistory(ret, round++, user, message.second); + } + } + ret = MakeInput(ret, round, user); + return ret; + } + return ApplyChatTemplate(ChatMessagesToJinjaVar(messages)); + } + + std::vector basellm::ApplyChatTemplateToTokens(const ChatMessages &messages) { + auto prompt = this->ApplyChatTemplate(messages); + auto input = this->weight.tokenizer.Encode(prompt); + std::vector tokens; + for (int i = 0; i < input.Count(0); i++) { + tokens.push_back(((float *) input.cpuData)[i]); + } + return tokens; + } + + std::string basellm::ApplyChatTemplate(const JinjaVar &var) { + AssertInFastLLM(this->weight.tokenizer.chatTemplate != "", + "ApplyChatTemplate error: model doesn't has chat_template."); + JinjaVar local = var; + for (auto &it : this->weight.tokenizer.tokenizerConfig.object_items()) { + if (it.first != "messages" && it.second.is_string()) { + local[it.first] = it.second.string_value(); + } else if (it.first.find_last_of("_token") != std::string::npos && it.second.is_object()) { + local[it.first] = it.second["content"].string_value(); + } + } + JinjaTemplate temp = JinjaTemplate(this->weight.tokenizer.chatTemplate); + return temp.Apply(local); + } + + std::vector basellm::ApplyChatTemplateToTokens(const JinjaVar &var) { + auto prompt = this->ApplyChatTemplate(var); + auto input = this->weight.tokenizer.Encode(prompt); + std::vector tokens; + for (int i = 0; i < input.Count(0); i++) { + tokens.push_back(((float *) input.cpuData)[i]); + } + return tokens; + } +} diff --git a/src/template.cpp b/src/template.cpp index f7bf4864..18df1cf9 100644 --- a/src/template.cpp +++ b/src/template.cpp @@ -197,6 +197,8 @@ namespace fastllm { type = JinjaBlockType::JinjaBlockSet; } else if (tokens[0].type == JinjaToken::JinjaTokenIf) { type = JinjaBlockType::JinjaBlockIf; + } else if (tokens[0].type == JinjaToken::JinjaTokenElseIf) { + type = JinjaBlockType::JinjaBlockElseIf; } else if (tokens[0].type == JinjaToken::JinjaTokenElse) { type = JinjaBlockType::JinjaBlockElse; } else if (tokens[0].type == JinjaToken::JinjaTokenEndif) { @@ -218,6 +220,8 @@ namespace fastllm { if (a.type == JinjaVar::JinjaString && b.type == JinjaVar::JinjaString) { return a.stringValue + b.stringValue; } + } else if (op == JinjaToken::JinjaTokenIn) { + return b.dictValue.find(a.stringValue) != b.dictValue.end(); } else if (op == JinjaToken::JinjaTokenEqual) { if (a.type != b.type) { return false; @@ -325,6 +329,7 @@ namespace fastllm { tokens[i].type == JinjaToken::JinjaTokenDiv || tokens[i].type == JinjaToken::JinjaTokenEqual || tokens[i].type == JinjaToken::JinjaTokenNotEqual || + tokens[i].type == JinjaToken::JinjaTokenIn || tokens[i].type == JinjaToken::JinjaTokenAnd || tokens[i].type == JinjaToken::JinjaTokenOr || tokens[i].type == JinjaToken::JinjaTokenFliter) { @@ -344,7 +349,10 @@ namespace fastllm { std::vector vars; for (auto &it : suffixExp) { if (it.type == JinjaToken::JinjaTokenID) { - vars.push_back(JinjaVar(JinjaVar::JinjaNone, it.value)); + if (it.value == "defined") + vars.push_back(local); + else + vars.push_back(JinjaVar(JinjaVar::JinjaNone, it.value)); } else if (it.type == JinjaToken::JinjaTokenBOOL) { vars.push_back(JinjaVar(it.value)); } else if (it.type == JinjaToken::JinjaTokenSTRING) { @@ -392,6 +400,7 @@ namespace fastllm { it.type == JinjaToken::JinjaTokenAssign || it.type == JinjaToken::JinjaTokenEqual || it.type == JinjaToken::JinjaTokenNotEqual || + it.type == 
JinjaToken::JinjaTokenIn ||
                           it.type == JinjaToken::JinjaTokenAnd ||
                           it.type == JinjaToken::JinjaTokenOr) {
                    AssertInFastLLM(vars.size() > 1, "Jinja Error: expression error.");
@@ -469,7 +478,7 @@ namespace fastllm {
                 }
                 var[iterId] = original;
                 i = endPos;
-            } else if (curBlock.type == JinjaBlock::JinjaBlockIf) {
+            } else if (curBlock.type == JinjaBlock::JinjaBlockIf || curBlock.type == JinjaBlock::JinjaBlockType::JinjaBlockElseIf) {
                 int cnt = 0;
                 int elsePos = -1;
                 int endPos = -1;
@@ -480,6 +489,11 @@
                         if (cnt == 0) {
                             elsePos = j;
                         }
+                    } else if (blocks[j].type == JinjaBlock::JinjaBlockType::JinjaBlockElseIf) {
+                        if (cnt == 0) {
+                            endPos = j;
+                            break;
+                        }
                     } else if (blocks[j].type == JinjaBlock::JinjaBlockType::JinjaBlockEndIf) {
                         if ((cnt--) == 0) {
                             endPos = j;
@@ -506,7 +520,10 @@
                         Parse(elsePos + 1, endPos, var, ret);
                     }
                 }
-                i = endPos;
+                if (blocks[endPos].type == JinjaBlock::JinjaBlockType::JinjaBlockElseIf)
+                    i = endPos - 1;
+                else
+                    i = endPos;
             } else if (curBlock.type == JinjaBlock::JinjaBlockSet) {
                 // 目前仅支持 "set 变量 = 表达式" 格式
                 AssertInFastLLM(
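
The template.cpp hunks above are what allow chat templates built around `elif` branches and `x is defined` tests (such as the DeepSeek V2 series templates named in the subject) to be parsed. Below is a minimal, editor-added sketch of what this enables, written against the JinjaTemplate/JinjaVar API used by ApplyChatTemplate; the initializer-list forms and the rendered output are assumptions inferred from the code above, not something this patch ships.

// Sketch only: assumes JinjaVar's initializer-list construction behaves as in
// ChatMessagesToJinjaVar above, and that Apply() returns the rendered string.
#include "template.h"
#include <cstdio>
#include <string>

int main() {
    // A toy template exercising the newly supported "elif" and "is defined".
    std::string tpl = "{% if system_message is defined %}{{ system_message }}"
                      "{% elif add_generation_prompt %}Assistant:"
                      "{% else %}(empty){% endif %}";
    fastllm::JinjaVar var = {
        {"system_message", "You are a helpful assistant."},
        {"add_generation_prompt", fastllm::JinjaVar{1}}
    };
    fastllm::JinjaTemplate temp = fastllm::JinjaTemplate(tpl);
    // Expected to take the first branch, since "system_message" is a key of var.
    printf("%s\n", temp.Apply(var).c_str());
    return 0;
}

Internally, `x is defined` is handled as a membership test against the local scope: the `is` keyword maps onto JinjaTokenIn, a `defined` identifier pushes the local variable dict, and the JinjaTokenIn operator checks whether the name is a key of that dict.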