Skip to content

Commit

Permalink
BERT类修改为一个支持embedding, reranker的基类
Browse files Browse the repository at this point in the history
  • Loading branch information
黄宇扬 committed Sep 23, 2024
1 parent 4736a6d commit 016eb32
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 122 deletions.
15 changes: 14 additions & 1 deletion include/models/bert.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
#include "fastllm.h"

namespace fastllm {
// 类BERT类大模型基础类
// 支持Compute-Score,计算两个token序列的相似程度(用于reranker)
// 支持Embedding,生成token序列的向量
class BertModel: public basellm {
public:
BertModel() {};
Expand All @@ -16,8 +19,10 @@ namespace fastllm {

void InitParams(); // 初始化参数信息

void Normalize(float *data, int dataLen);

// 推理
std::vector <std::vector <float> > ForwardAll(
virtual std::vector <std::vector <float> > ForwardAll(
const Data &inputIds,
const Data &attentionMask,
const Data &tokenTypeIds,
Expand All @@ -34,6 +39,14 @@ namespace fastllm {
const LastTokensManager &lastTokens = LastTokensManager(),
std::vector <float> *logits = nullptr);

virtual void FillBertInputsBatch(const std::vector <std::vector <int> > &tokens,
Data &inputIds, Data &attentionMask, Data &tokenTypeIds, Data &positionIds);

// 计算相似分数
// tokens: 输入tokens, tokens[i]代表第i个输入的token序列
// ret: ret[i]代表第i个输入的相似度
std::vector <float> ComputeScore(std::vector <std::vector <int> > tokens);

std::vector <float> EmbeddingSentence(const std::vector <int> &tokens, bool normalize);

std::vector <std::vector <float> > EmbeddingSentenceBatch(const std::vector <std::vector <int> > &tokens, bool normalize);
Expand Down
33 changes: 7 additions & 26 deletions include/models/xlmroberta.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
#define FASTLLM_XLMROBERTA_H

#include "basellm.h"
#include "bert.h"
#include "fastllm.h"

namespace fastllm {
class XlmRobertaModel : basellm {
class XlmRobertaModel : BertModel {
public:
XlmRobertaModel();

Expand All @@ -16,36 +17,16 @@ namespace fastllm {

void InitParams(); // 初始化参数信息

// 推理
int Forward(
const Data &inputIds,
const Data &attentionMask,
const Data &positionIds,
std::vector <std::pair <Data, Data> > &pastKeyValues,
const GenerationConfig &generationConfig = GenerationConfig(),
const LastTokensManager &lastTokens = LastTokensManager(),
std::vector <float> *logits = nullptr) {return 0;}

std::string MakeInput(const std::string &history, int round, const std::string &input) {return "";}
std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output) {return "";}

// 计算相似分数
// tokens: 输入tokens, tokens[i]代表第i个输入的token序列
// ret: ret[i]代表第i个输入的相似度
std::vector <float> ComputeScore(std::vector <std::vector <int> > tokens);
void FillBertInputsBatch(const std::vector <std::vector <int> > &tokens,
Data &inputIds, Data &attentionMask, Data &tokenTypeIds, Data &positionIds);

// 推理
std::vector <float> Forward(
std::vector <std::vector <float> > ForwardAll(
const Data &inputIds,
const Data &attentionMask,
const Data &tokenTypeIds,
const Data &positionIds);

std::vector <float> EmbeddingSentence(const std::string &context);

std::vector <std::vector <float> > EmbeddingSentenceBatch(const std::vector <std::string> &contexts);

void LoadFromFile(const std::string &fileName); // 从文件读取
const Data &positionIds,
bool normalize);

void WarmUp(); // 预热

Expand Down
87 changes: 46 additions & 41 deletions src/models/bert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,19 @@ namespace fastllm {
this->head_dim = embed_dim / num_attention_heads;
}

void Normalize(float *data, int dataLen)
{
void BertModel::Normalize(float *data, int dataLen) {
float sum = 0.0;
for(int i = 0; i < dataLen; i++)
for (int i = 0; i < dataLen; i++) {
sum += data[i] * data[i];

if (sum < 1e-6) sum = 1e-6;
else sum = sqrt(sum);

for(int i = 0; i < dataLen; i++)
}
if (sum < 1e-6) {
sum = 1e-6;
} else {
sum = sqrt(sum);
}
for (int i = 0; i < dataLen; i++) {
data[i] = data[i] / sum;
}
}

std::vector <std::vector <float> > BertModel::ForwardAll(
Expand Down Expand Up @@ -128,20 +130,18 @@ namespace fastllm {
std::vector <std::vector <float> > ret;
ret.resize(batch, std::vector <float> (outputDim, 0.0f));
for (int i = 0; i < batch; i++) {
if(normalize) Normalize(fret + i * outputDim, outputDim);
if (normalize) {
Normalize(fret + i * outputDim, outputDim);
}
memcpy(ret[i].data(), fret + i * outputDim, outputDim * sizeof(float));
}

return ret;
}

std::vector <float> BertModel::EmbeddingSentence(const std::vector <int> &tokens, bool normalize) {
std::vector <std::vector <int> > tokenss;
tokenss.push_back(tokens);
return EmbeddingSentenceBatch(tokenss, normalize)[0];
}

std::vector <std::vector <float> > BertModel::EmbeddingSentenceBatch(const std::vector <std::vector <int> > &tokens, bool normalize) {
// 根据输入的input
void BertModel::FillBertInputsBatch(const std::vector <std::vector <int> > &tokens,
Data &inputIds, Data &attentionMask, Data &tokenTypeIds, Data &positionIds) {
int batch = tokens.size(), len = 0;
for (int i = 0; i < batch; i++) {
len = std::max(len, (int)tokens[i].size());
Expand All @@ -160,12 +160,36 @@ namespace fastllm {
position_ids[i * len + j] = j;
}
}
inputIds.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, ids));
attentionMask.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, attention_mask));
tokenTypeIds.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, token_type_ids));
positionIds.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, position_ids));
}

fastllm::Data inputIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, ids);
fastllm::Data attentionMask = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, attention_mask);
fastllm::Data tokenTypeIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, token_type_ids);
fastllm::Data positionIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, position_ids);
// Computes reranker similarity scores for a batch of token sequences.
// tokens: tokens[i] is the i-th input token sequence (e.g. a query/passage pair).
// Returns: ret[i] is the similarity score of the i-th input, taken as the first
//          logit of the raw (un-normalized) model output.
std::vector <float> BertModel::ComputeScore(std::vector <std::vector <int> > tokens) {
    fastllm::Data inputIds, attentionMask, tokenTypeIds, positionIds;
    FillBertInputsBatch(tokens, inputIds, attentionMask, tokenTypeIds, positionIds);
    // normalize = false: the score is the raw first logit, so no L2 normalization.
    auto ret = ForwardAll(inputIds, attentionMask, tokenTypeIds, positionIds, false);
    std::vector <float> lastRet;
    lastRet.reserve(ret.size());
    for (const auto &row : ret) {
        lastRet.push_back(row[0]);
    }
    return lastRet;
}

// Generates the embedding vector for a single token sequence.
// Convenience wrapper: wraps the sequence into a batch of one, delegates to
// EmbeddingSentenceBatch, and returns the only row of the result.
std::vector <float> BertModel::EmbeddingSentence(const std::vector <int> &tokens, bool normalize) {
    std::vector <std::vector <int> > singleBatch { tokens };
    return EmbeddingSentenceBatch(singleBatch, normalize)[0];
}

// Generates embedding vectors for a batch of token sequences.
// tokens: tokens[i] is the i-th input token sequence.
// normalize: when true, each returned vector is L2-normalized by ForwardAll.
// Returns: one embedding vector per input sequence.
std::vector <std::vector <float> > BertModel::EmbeddingSentenceBatch(const std::vector <std::vector <int> > &tokens, bool normalize) {
    fastllm::Data inputIds, attentionMask, tokenTypeIds, positionIds;
    // FillBertInputsBatch pads all sequences to the batch max length and builds
    // the four input tensors; nothing else needs to be computed here.
    FillBertInputsBatch(tokens, inputIds, attentionMask, tokenTypeIds, positionIds);
    return ForwardAll(inputIds, attentionMask, tokenTypeIds, positionIds, normalize);
}

Expand All @@ -186,27 +210,8 @@ namespace fastllm {
}
len = std::max(len, (int)tokens[i].size());
}

std::vector <float> ids = std::vector <float> (batch * len, 0.0f);
std::vector <float> seqLens = std::vector <float> (batch, 0.0f);
std::vector <float> token_type_ids = std::vector <float> (batch * len, 0.0f);
std::vector <float> attention_mask = std::vector <float> (batch * len, -1e10f);
std::vector <float> position_ids = std::vector <float> (batch * len, 0.0f);
for (int i = 0; i < batch; i++) {
seqLens[i] = tokens[i].size();
for (int j = 0; j < tokens[i].size(); j++) {
ids[i * len + j] = tokens[i][j];
attention_mask[i * len + j] = 0;
position_ids[i * len + j] = j;
}
}

fastllm::Data inputIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, ids);
fastllm::Data attentionMask = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, attention_mask);
fastllm::Data tokenTypeIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, token_type_ids);
fastllm::Data positionIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, position_ids);

// printf("bs = %d, len = %d\n", batch, len); ClearProfiler(); Forward(inputIds, attentionMask, tokenTypeIds, positionIds); PrintProfiler();
fastllm::Data inputIds, attentionMask, tokenTypeIds, positionIds;
FillBertInputsBatch(tokens, inputIds, attentionMask, tokenTypeIds, positionIds);
return ForwardAll(inputIds, attentionMask, tokenTypeIds, positionIds, normalize);
}

Expand Down
106 changes: 53 additions & 53 deletions src/models/xlmroberta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@ namespace fastllm {
};
}

void XlmRobertaModel::LoadFromFile(const std::string &fileName) {
this->weight.LoadFromFile(fileName);
InitParams();
}

void XlmRobertaModel::InitParams() {
if (this->weight.dicts.find("layer_norm_eps") != this->weight.dicts.end()) {
this->layer_norm_eps = atof(this->weight.dicts["layer_norm_eps"].c_str());
Expand All @@ -45,12 +40,51 @@ namespace fastllm {
this->head_dim = embed_dim / num_attention_heads;
}

std::vector <float> XlmRobertaModel::Forward(
// Builds the batched input tensors for XLM-RoBERTa from raw token sequences.
// tokens: tokens[i] is the i-th input token sequence; shorter sequences are
//         right-padded to the longest sequence in the batch.
// Outputs (all FLOAT32, shape {batch, len}):
//   inputIds      - token ids, 0 at padding positions
//   attentionMask - 0 for real tokens, -1e10 at padding positions
//   tokenTypeIds  - all zeros (single-segment input)
//   positionIds   - positions starting at 2 + j (offset matches RoBERTa's
//                   position-id convention; verify against the checkpoint)
void XlmRobertaModel::FillBertInputsBatch(const std::vector <std::vector <int> > &tokens,
        Data &inputIds, Data &attentionMask, Data &tokenTypeIds, Data &positionIds) {
    int batch = tokens.size(), len = 0;
    for (int i = 0; i < batch; i++) {
        len = std::max(len, (int)tokens[i].size());
    }

    std::vector <float> ids(batch * len, 0.0f);
    std::vector <float> token_type_ids(batch * len, 0.0f);
    std::vector <float> attention_mask(batch * len, -1e10f);
    std::vector <float> position_ids(batch * len, 0.0f);
    for (int i = 0; i < batch; i++) {
        for (int j = 0; j < (int)tokens[i].size(); j++) {
            ids[i * len + j] = tokens[i][j];
            attention_mask[i * len + j] = 0;
            position_ids[i * len + j] = 2 + j;
        }
    }
    inputIds.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, ids));
    attentionMask.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, attention_mask));
    tokenTypeIds.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, token_type_ids));
    positionIds.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, position_ids));
}

// L2-normalizes `data` in place: divides each element by the vector's
// Euclidean norm. If the squared norm falls below 1e-6 the divisor is clamped
// to 1e-6 (note: the raw floor, not its square root) to avoid dividing by
// zero on (near-)zero vectors.
void Normalize__(float *data, int dataLen)
{
    float squareSum = 0.0f;
    for (int idx = 0; idx < dataLen; idx++) {
        squareSum += data[idx] * data[idx];
    }
    const float divisor = (squareSum < 1e-6) ? 1e-6f : (float)sqrt(squareSum);
    for (int idx = 0; idx < dataLen; idx++) {
        data[idx] /= divisor;
    }
}

std::vector <std::vector <float> > XlmRobertaModel::ForwardAll(
const Data &inputIds,
const Data &attentionMask,
const Data &tokenTypeIds,
const Data &positionIds) {
// embedding
const Data &positionIds,
bool normalize) {
Data inputEmbeddings, tokenTypeEmbeddings, positionIdEmbeddings;
Embedding(inputIds, this->weight["roberta.embeddings.word_embeddings.weight"], inputEmbeddings);
Embedding(tokenTypeIds, this->weight["roberta.embeddings.token_type_embeddings.weight"], tokenTypeEmbeddings);
Expand All @@ -59,7 +93,6 @@ namespace fastllm {
AddTo(inputEmbeddings, positionIdEmbeddings);
Data hiddenStates, firstStates;
LayerNorm(inputEmbeddings, this->weight["roberta.embeddings.LayerNorm.weight"], this->weight["roberta.embeddings.LayerNorm.bias"], -1, hiddenStates);

Data q, k, v, qk, qkv, attnOutput, inter, pooler, logits;
for (int i = 0; i < this->block_cnt; i++) {
std::string queryWeightName = "roberta.encoder.layer." + std::to_string(i) + ".attention.self.query.weight";
Expand Down Expand Up @@ -117,59 +150,26 @@ namespace fastllm {

Split(hiddenStates, 1, 0, 1, firstStates);
firstStates.Reshape({firstStates.dims[0], -1});
Linear(firstStates, this->weight["classifier.dense.weight"], this->weight["classifier.dense.bias"], pooler);
TanH(pooler, pooler);
Linear(pooler, this->weight["classifier.out_proj.weight"], this->weight["classifier.out_proj.bias"], logits);
if (this->weight.weight.find("classifier.dense.weight") != this->weight.weight.end()) {
Linear(firstStates, this->weight["classifier.dense.weight"], this->weight["classifier.dense.bias"], pooler);
TanH(pooler, pooler);
Linear(pooler, this->weight["classifier.out_proj.weight"], this->weight["classifier.out_proj.bias"], logits);
} else {
Mul(firstStates, 1.0f, logits);
}

logits.ToDevice(DataDevice::CPU);
float *fret = (float*)logits.cpuData;
int batch = logits.dims[0], outputDim = logits.dims[1];
std::vector <std::vector <float> > ret;
ret.resize(batch, std::vector <float> (outputDim, 0.0f));
for (int i = 0; i < batch; i++) {
memcpy(ret[i].data(), fret + i * outputDim, outputDim * sizeof(float));
}

std::vector <float> lastRet;
for (int i = 0; i < batch; i++) {
lastRet.push_back(ret[i][0]);
}

return lastRet;
}

std::vector <float> XlmRobertaModel::ComputeScore(std::vector <std::vector <int> > tokens) {
int batch = tokens.size(), maxLen = tokens[0].size();
for (int i = 0; i < tokens.size(); i++) {
maxLen = std::max(maxLen, (int)tokens[i].size());
}
std::vector <float> inputIds = std::vector <float> (batch * maxLen, 1.0f);
std::vector <float> attentionMasks = std::vector <float> (batch * maxLen, -10000.0f);
std::vector <float> positionIds = std::vector <float> (batch * maxLen, 0.0f);
std::vector <float> tokenTypeIds = std::vector <float> (batch * maxLen, 0.0f);
for (int i = 0; i < batch; i++) {
for (int j = 0; j < (int)tokens[i].size(); j++) {
inputIds[i * maxLen + j] = tokens[i][j];
attentionMasks[i * maxLen + j] = 0.0f;
positionIds[i * maxLen + j] = 2 + j;
if (normalize) {
Normalize(fret + i * outputDim, outputDim);
}
memcpy(ret[i].data(), fret + i * outputDim, outputDim * sizeof(float));
}

fastllm::Data inputIdsData = fastllm::Data (fastllm::DataType::FLOAT32, {batch, maxLen}, inputIds);
fastllm::Data attentionMasksData = fastllm::Data (fastllm::DataType::FLOAT32, {batch, maxLen}, attentionMasks);
fastllm::Data positionIdsData = fastllm::Data (fastllm::DataType::FLOAT32, {batch, maxLen}, positionIds);
fastllm::Data tokenTypeIdsData = fastllm::Data (fastllm::DataType::FLOAT32, {batch, maxLen}, tokenTypeIds);
return Forward(inputIdsData, attentionMasksData, tokenTypeIdsData, positionIdsData);
}

std::vector <float> XlmRobertaModel::EmbeddingSentence(const std::string &context) {
std::vector <std::string> contexts;
contexts.push_back(context);
return EmbeddingSentenceBatch(contexts)[0];
}

std::vector <std::vector <float> > XlmRobertaModel::EmbeddingSentenceBatch(const std::vector <std::string> &contexts) {
return {};
return ret;
}

void XlmRobertaModel::WarmUp() {
Expand Down
2 changes: 1 addition & 1 deletion tools/src/pytools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ extern "C" {
}

DLL_EXPORT float* reranker_compute_score(int modelId, int batch, int *seqLens, int *tokens) {
fastllm::XlmRobertaModel *model = (fastllm::XlmRobertaModel*)models.GetModel(modelId);
fastllm::BertModel *model = (fastllm::BertModel*)models.GetModel(modelId);
std::vector <std::vector <int> > inputIds;
inputIds.resize(batch);
int pos = 0;
Expand Down

0 comments on commit 016eb32

Please sign in to comment.