From 016eb320b447c05fa2f5397b4bea7fcf23848138 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?= Date: Mon, 23 Sep 2024 16:37:18 +0800 Subject: [PATCH] =?UTF-8?q?BERT=E7=B1=BB=E4=BF=AE=E6=94=B9=E4=B8=BA?= =?UTF-8?q?=E4=B8=80=E4=B8=AA=E6=94=AF=E6=8C=81embedding,=20reranker?= =?UTF-8?q?=E7=9A=84=E5=9F=BA=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/models/bert.h | 15 ++++- include/models/xlmroberta.h | 33 +++-------- src/models/bert.cpp | 87 +++++++++++++++-------------- src/models/xlmroberta.cpp | 106 ++++++++++++++++++------------------ tools/src/pytools.cpp | 2 +- 5 files changed, 121 insertions(+), 122 deletions(-) diff --git a/include/models/bert.h b/include/models/bert.h index ef70ecd..91c9ccb 100644 --- a/include/models/bert.h +++ b/include/models/bert.h @@ -6,6 +6,9 @@ #include "fastllm.h" namespace fastllm { + // 类BERT类大模型基础类 + // 支持Compute-Score,计算两个token序列的相似程度(用于reranker) + // 支持Embedding,生成token序列的向量 class BertModel: public basellm { public: BertModel() {}; @@ -16,8 +19,10 @@ namespace fastllm { void InitParams(); // 初始化参数信息 + void Normalize(float *data, int dataLen); + // 推理 - std::vector > ForwardAll( + virtual std::vector > ForwardAll( const Data &inputIds, const Data &attentionMask, const Data &tokenTypeIds, @@ -34,6 +39,14 @@ namespace fastllm { const LastTokensManager &lastTokens = LastTokensManager(), std::vector *logits = nullptr); + virtual void FillBertInputsBatch(const std::vector > &tokens, + Data &inputIds, Data &attentionMask, Data &tokenTypeIds, Data &positionIds); + + // 计算相似分数 + // tokens: 输入tokens, tokens[i]代表第i个输入的token序列 + // ret: ret[i]代表第i个输入的相似度 + std::vector ComputeScore(std::vector > tokens); + std::vector EmbeddingSentence(const std::vector &tokens, bool normalize); std::vector > EmbeddingSentenceBatch(const std::vector > &tokens, bool normalize); diff --git a/include/models/xlmroberta.h b/include/models/xlmroberta.h index 331a25e..c69cf74 100644 --- a/include/models/xlmroberta.h +++ b/include/models/xlmroberta.h @@ -3,10 +3,11 @@ #define FASTLLM_XLMROBERTA_H #include "basellm.h" +#include "bert.h" #include "fastllm.h" namespace fastllm { - class XlmRobertaModel : basellm { + class XlmRobertaModel : BertModel { public: XlmRobertaModel(); @@ -16,36 +17,16 @@ namespace fastllm { void InitParams(); // 初始化参数信息 - // 推理 - int Forward( - const Data &inputIds, - const Data &attentionMask, - const Data &positionIds, - std::vector > &pastKeyValues, - const GenerationConfig &generationConfig = GenerationConfig(), - const LastTokensManager &lastTokens = LastTokensManager(), - std::vector *logits = nullptr) {return 0;} - - std::string MakeInput(const std::string &history, int round, const std::string &input) {return "";} - std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output) {return "";} - - // 计算相似分数 - // tokens: 输入tokens, tokens[i]代表第i个输入的token序列 - // ret: ret[i]代表第i个输入的相似度 - std::vector ComputeScore(std::vector > tokens); + void FillBertInputsBatch(const std::vector > &tokens, + Data &inputIds, Data &attentionMask, Data &tokenTypeIds, Data &positionIds); // 推理 - std::vector Forward( + std::vector > ForwardAll( const Data &inputIds, const Data &attentionMask, const Data &tokenTypeIds, - const Data &positionIds); - - std::vector EmbeddingSentence(const std::string &context); - - std::vector > EmbeddingSentenceBatch(const std::vector &contexts); - - void LoadFromFile(const std::string &fileName); // 从文件读取 + const Data &positionIds, + bool normalize); void WarmUp(); // 预热 diff --git a/src/models/bert.cpp b/src/models/bert.cpp index 02875a8..66fa054 100644 --- a/src/models/bert.cpp +++ b/src/models/bert.cpp @@ -31,17 +31,19 @@ namespace fastllm { this->head_dim = embed_dim / num_attention_heads; } - void Normalize(float *data, int dataLen) - { + void BertModel::Normalize(float *data, int dataLen) { float sum = 0.0; - for(int i = 0; i < dataLen; i++) + for (int i = 0; i < dataLen; i++) { sum += data[i] * data[i]; - - if (sum < 1e-6) sum = 1e-6; - else sum = sqrt(sum); - - for(int i = 0; i < dataLen; i++) + } + if (sum < 1e-6) { + sum = 1e-6; + } else { + sum = sqrt(sum); + } + for (int i = 0; i < dataLen; i++) { data[i] = data[i] / sum; + } } std::vector > BertModel::ForwardAll( @@ -128,20 +130,18 @@ namespace fastllm { std::vector > ret; ret.resize(batch, std::vector (outputDim, 0.0f)); for (int i = 0; i < batch; i++) { - if(normalize) Normalize(fret + i * outputDim, outputDim); + if (normalize) { + Normalize(fret + i * outputDim, outputDim); + } memcpy(ret[i].data(), fret + i * outputDim, outputDim * sizeof(float)); } return ret; } - std::vector BertModel::EmbeddingSentence(const std::vector &tokens, bool normalize) { - std::vector > tokenss; - tokenss.push_back(tokens); - return EmbeddingSentenceBatch(tokenss, normalize)[0]; - } - - std::vector > BertModel::EmbeddingSentenceBatch(const std::vector > &tokens, bool normalize) { + // 根据输入的input + void BertModel::FillBertInputsBatch(const std::vector > &tokens, + Data &inputIds, Data &attentionMask, Data &tokenTypeIds, Data &positionIds) { int batch = tokens.size(), len = 0; for (int i = 0; i < batch; i++) { len = std::max(len, (int)tokens[i].size()); @@ -160,12 +160,36 @@ namespace fastllm { position_ids[i * len + j] = j; } } + inputIds.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, ids)); + attentionMask.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, attention_mask)); + tokenTypeIds.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, token_type_ids)); + positionIds.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, position_ids)); + } - fastllm::Data inputIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, ids); - fastllm::Data attentionMask = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, attention_mask); - fastllm::Data tokenTypeIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, token_type_ids); - fastllm::Data positionIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, position_ids); + std::vector BertModel::ComputeScore(std::vector > tokens) { + fastllm::Data inputIds, attentionMask, tokenTypeIds, positionIds; + FillBertInputsBatch(tokens, inputIds, attentionMask, tokenTypeIds, positionIds); + auto ret = ForwardAll(inputIds, attentionMask, tokenTypeIds, positionIds, false); + std::vector lastRet; + for (int i = 0; i < ret.size(); i++) { + lastRet.push_back(ret[i][0]); + } + return lastRet; + } + + std::vector BertModel::EmbeddingSentence(const std::vector &tokens, bool normalize) { + std::vector > tokenss; + tokenss.push_back(tokens); + return EmbeddingSentenceBatch(tokenss, normalize)[0]; + } + std::vector > BertModel::EmbeddingSentenceBatch(const std::vector > &tokens, bool normalize) { + fastllm::Data inputIds, attentionMask, tokenTypeIds, positionIds; + FillBertInputsBatch(tokens, inputIds, attentionMask, tokenTypeIds, positionIds); + int batch = tokens.size(), len = 0; + for (int i = 0; i < batch; i++) { + len = std::max(len, (int)tokens[i].size()); + } return ForwardAll(inputIds, attentionMask, tokenTypeIds, positionIds, normalize); } @@ -186,27 +210,8 @@ namespace fastllm { } len = std::max(len, (int)tokens[i].size()); } - - std::vector ids = std::vector (batch * len, 0.0f); - std::vector seqLens = std::vector (batch, 0.0f); - std::vector token_type_ids = std::vector (batch * len, 0.0f); - std::vector attention_mask = std::vector (batch * len, -1e10f); - std::vector position_ids = std::vector (batch * len, 0.0f); - for (int i = 0; i < batch; i++) { - seqLens[i] = tokens[i].size(); - for (int j = 0; j < tokens[i].size(); j++) { - ids[i * len + j] = tokens[i][j]; - attention_mask[i * len + j] = 0; - position_ids[i * len + j] = j; - } - } - - fastllm::Data inputIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, ids); - fastllm::Data attentionMask = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, attention_mask); - fastllm::Data tokenTypeIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, token_type_ids); - fastllm::Data positionIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, position_ids); - -// printf("bs = %d, len = %d\n", batch, len); ClearProfiler(); Forward(inputIds, attentionMask, tokenTypeIds, positionIds); PrintProfiler(); + fastllm::Data inputIds, attentionMask, tokenTypeIds, positionIds; + FillBertInputsBatch(tokens, inputIds, attentionMask, tokenTypeIds, positionIds); return ForwardAll(inputIds, attentionMask, tokenTypeIds, positionIds, normalize); } diff --git a/src/models/xlmroberta.cpp b/src/models/xlmroberta.cpp index ec97de1..0c03bc7 100644 --- a/src/models/xlmroberta.cpp +++ b/src/models/xlmroberta.cpp @@ -22,11 +22,6 @@ namespace fastllm { }; } - void XlmRobertaModel::LoadFromFile(const std::string &fileName) { - this->weight.LoadFromFile(fileName); - InitParams(); - } - void XlmRobertaModel::InitParams() { if (this->weight.dicts.find("layer_norm_eps") != this->weight.dicts.end()) { this->layer_norm_eps = atof(this->weight.dicts["layer_norm_eps"].c_str()); @@ -45,12 +40,51 @@ namespace fastllm { this->head_dim = embed_dim / num_attention_heads; } - std::vector XlmRobertaModel::Forward( + void XlmRobertaModel::FillBertInputsBatch(const std::vector > &tokens, + Data &inputIds, Data &attentionMask, Data &tokenTypeIds, Data &positionIds) { + int batch = tokens.size(), len = 0; + for (int i = 0; i < batch; i++) { + len = std::max(len, (int)tokens[i].size()); + } + + std::vector ids = std::vector (batch * len, 0.0f); + std::vector seqLens = std::vector (batch, 0.0f); + std::vector token_type_ids = std::vector (batch * len, 0.0f); + std::vector attention_mask = std::vector (batch * len, -1e10f); + std::vector position_ids = std::vector (batch * len, 0.0f); + for (int i = 0; i < batch; i++) { + seqLens[i] = tokens[i].size(); + for (int j = 0; j < tokens[i].size(); j++) { + ids[i * len + j] = tokens[i][j]; + attention_mask[i * len + j] = 0; + position_ids[i * len + j] = 2 + j; + } + } + inputIds.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, ids)); + attentionMask.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, attention_mask)); + tokenTypeIds.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, token_type_ids)); + positionIds.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, position_ids)); + } + + void Normalize__(float *data, int dataLen) + { + float sum = 0.0; + for(int i = 0; i < dataLen; i++) + sum += data[i] * data[i]; + + if (sum < 1e-6) sum = 1e-6; + else sum = sqrt(sum); + + for(int i = 0; i < dataLen; i++) + data[i] = data[i] / sum; + } + + std::vector > XlmRobertaModel::ForwardAll( const Data &inputIds, const Data &attentionMask, const Data &tokenTypeIds, - const Data &positionIds) { - // embedding + const Data &positionIds, + bool normalize) { Data inputEmbeddings, tokenTypeEmbeddings, positionIdEmbeddings; Embedding(inputIds, this->weight["roberta.embeddings.word_embeddings.weight"], inputEmbeddings); Embedding(tokenTypeIds, this->weight["roberta.embeddings.token_type_embeddings.weight"], tokenTypeEmbeddings); @@ -59,7 +93,6 @@ namespace fastllm { AddTo(inputEmbeddings, positionIdEmbeddings); Data hiddenStates, firstStates; LayerNorm(inputEmbeddings, this->weight["roberta.embeddings.LayerNorm.weight"], this->weight["roberta.embeddings.LayerNorm.bias"], -1, hiddenStates); - Data q, k, v, qk, qkv, attnOutput, inter, pooler, logits; for (int i = 0; i < this->block_cnt; i++) { std::string queryWeightName = "roberta.encoder.layer." + std::to_string(i) + ".attention.self.query.weight"; @@ -117,9 +150,13 @@ namespace fastllm { Split(hiddenStates, 1, 0, 1, firstStates); firstStates.Reshape({firstStates.dims[0], -1}); - Linear(firstStates, this->weight["classifier.dense.weight"], this->weight["classifier.dense.bias"], pooler); - TanH(pooler, pooler); - Linear(pooler, this->weight["classifier.out_proj.weight"], this->weight["classifier.out_proj.bias"], logits); + if (this->weight.weight.find("classifier.dense.weight") != this->weight.weight.end()) { + Linear(firstStates, this->weight["classifier.dense.weight"], this->weight["classifier.dense.bias"], pooler); + TanH(pooler, pooler); + Linear(pooler, this->weight["classifier.out_proj.weight"], this->weight["classifier.out_proj.bias"], logits); + } else { + Mul(firstStates, 1.0f, logits); + } logits.ToDevice(DataDevice::CPU); float *fret = (float*)logits.cpuData; @@ -127,49 +164,12 @@ namespace fastllm { std::vector > ret; ret.resize(batch, std::vector (outputDim, 0.0f)); for (int i = 0; i < batch; i++) { - memcpy(ret[i].data(), fret + i * outputDim, outputDim * sizeof(float)); - } - - std::vector lastRet; - for (int i = 0; i < batch; i++) { - lastRet.push_back(ret[i][0]); - } - - return lastRet; - } - - std::vector XlmRobertaModel::ComputeScore(std::vector > tokens) { - int batch = tokens.size(), maxLen = tokens[0].size(); - for (int i = 0; i < tokens.size(); i++) { - maxLen = std::max(maxLen, (int)tokens[i].size()); - } - std::vector inputIds = std::vector (batch * maxLen, 1.0f); - std::vector attentionMasks = std::vector (batch * maxLen, -10000.0f); - std::vector positionIds = std::vector (batch * maxLen, 0.0f); - std::vector tokenTypeIds = std::vector (batch * maxLen, 0.0f); - for (int i = 0; i < batch; i++) { - for (int j = 0; j < (int)tokens[i].size(); j++) { - inputIds[i * maxLen + j] = tokens[i][j]; - attentionMasks[i * maxLen + j] = 0.0f; - positionIds[i * maxLen + j] = 2 + j; + if (normalize) { + Normalize(fret + i * outputDim, outputDim); } + memcpy(ret[i].data(), fret + i * outputDim, outputDim * sizeof(float)); } - - fastllm::Data inputIdsData = fastllm::Data (fastllm::DataType::FLOAT32, {batch, maxLen}, inputIds); - fastllm::Data attentionMasksData = fastllm::Data (fastllm::DataType::FLOAT32, {batch, maxLen}, attentionMasks); - fastllm::Data positionIdsData = fastllm::Data (fastllm::DataType::FLOAT32, {batch, maxLen}, positionIds); - fastllm::Data tokenTypeIdsData = fastllm::Data (fastllm::DataType::FLOAT32, {batch, maxLen}, tokenTypeIds); - return Forward(inputIdsData, attentionMasksData, tokenTypeIdsData, positionIdsData); - } - - std::vector XlmRobertaModel::EmbeddingSentence(const std::string &context) { - std::vector contexts; - contexts.push_back(context); - return EmbeddingSentenceBatch(contexts)[0]; - } - - std::vector > XlmRobertaModel::EmbeddingSentenceBatch(const std::vector &contexts) { - return {}; + return ret; } void XlmRobertaModel::WarmUp() { diff --git a/tools/src/pytools.cpp b/tools/src/pytools.cpp index 23e1878..378f4a2 100644 --- a/tools/src/pytools.cpp +++ b/tools/src/pytools.cpp @@ -450,7 +450,7 @@ extern "C" { } DLL_EXPORT float* reranker_compute_score(int modelId, int batch, int *seqLens, int *tokens) { - fastllm::XlmRobertaModel *model = (fastllm::XlmRobertaModel*)models.GetModel(modelId); + fastllm::BertModel *model = (fastllm::BertModel*)models.GetModel(modelId); std::vector > inputIds; inputIds.resize(batch); int pos = 0;