Modify the Python interface of the embedding and reranker models
黄宇扬 committed Sep 22, 2024
1 parent 0222fa1 commit 5b89fab
Showing 4 changed files with 95 additions and 3 deletions.
7 changes: 5 additions & 2 deletions include/models/bert.h
@@ -33,7 +33,11 @@ namespace fastllm {
const GenerationConfig &generationConfig = GenerationConfig(),
const LastTokensManager &lastTokens = LastTokensManager(),
std::vector <float> *logits = nullptr);


std::vector <float> EmbeddingSentence(const std::vector <int> &tokens, bool normalize);

std::vector <std::vector <float> > EmbeddingSentenceBatch(const std::vector <std::vector <int> > &tokens, bool normalize);

std::vector <float> EmbeddingSentence(const std::string &context, bool normalize);

std::vector <std::vector <float> > EmbeddingSentenceBatch(const std::vector <std::string> &contexts, bool normalize);
@@ -56,7 +60,6 @@ namespace fastllm {
int max_positions = 32768;
int block_cnt = 12;

WeightMap weight; // weights
std::map <std::string, int> deviceMap;
};
}
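The new overloads take token ids instead of raw strings, so tokenization can happen outside the model, e.g. with a Hugging Face tokenizer. A minimal sketch of producing ids for the token-based path (the tokenizer name is only an example):

import ctypes
from transformers import AutoTokenizer

# Any BERT-family tokenizer works; the model name here is just an example.
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-zh-v1.5")
ids = tokenizer("hello world", truncation = True)["input_ids"]
# `ids` is the payload the new EmbeddingSentence(const std::vector<int>&, bool)
# overload consumes.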
34 changes: 34 additions & 0 deletions src/models/bert.cpp
@@ -135,6 +135,40 @@ namespace fastllm {
return ret;
}

std::vector <float> BertModel::EmbeddingSentence(const std::vector <int> &tokens, bool normalize) {
std::vector <std::vector <int> > tokenss;
tokenss.push_back(tokens);
return EmbeddingSentenceBatch(tokenss, normalize)[0];
}

std::vector <std::vector <float> > BertModel::EmbeddingSentenceBatch(const std::vector <std::vector <int> > &tokens, bool normalize) {
int batch = tokens.size(), len = 0;
for (int i = 0; i < batch; i++) {
len = std::max(len, (int)tokens[i].size());
}

std::vector <float> ids = std::vector <float> (batch * len, 0.0f);
std::vector <float> seqLens = std::vector <float> (batch, 0.0f);
std::vector <float> token_type_ids = std::vector <float> (batch * len, 0.0f);
std::vector <float> attention_mask = std::vector <float> (batch * len, -1e10f);
std::vector <float> position_ids = std::vector <float> (batch * len, 0.0f);
for (int i = 0; i < batch; i++) {
seqLens[i] = tokens[i].size();
for (int j = 0; j < tokens[i].size(); j++) {
ids[i * len + j] = tokens[i][j];
attention_mask[i * len + j] = 0;
position_ids[i * len + j] = j;
}
}

fastllm::Data inputIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, ids);
fastllm::Data attentionMask = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, attention_mask);
fastllm::Data tokenTypeIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, token_type_ids);
fastllm::Data positionIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, position_ids);

return ForwardAll(inputIds, attentionMask, tokenTypeIds, positionIds, normalize);
}

std::vector <float> BertModel::EmbeddingSentence(const std::string &context, bool normalize) {
std::vector <std::string> contexts;
contexts.push_back(context);
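EmbeddingSentenceBatch pads every sequence to the longest one in the batch: positions holding real tokens get an attention-mask value of 0, while padding keeps -1e10, which effectively removes it from the softmax. A minimal NumPy sketch of the same packing logic, for illustration only (the function name is made up):

import numpy as np

def pack_batch(token_batches):
    # Mirrors the padding scheme in BertModel::EmbeddingSentenceBatch.
    batch = len(token_batches)
    max_len = max(len(t) for t in token_batches)
    ids = np.zeros((batch, max_len), dtype = np.float32)
    attention_mask = np.full((batch, max_len), -1e10, dtype = np.float32)
    position_ids = np.zeros((batch, max_len), dtype = np.float32)
    for i, toks in enumerate(token_batches):
        ids[i, :len(toks)] = toks                # token ids, zero-padded on the right
        attention_mask[i, :len(toks)] = 0.0      # 0 = attend; -1e10 = masked padding
        position_ids[i, :len(toks)] = np.arange(len(toks))
    return ids, attention_mask, position_ids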
26 changes: 25 additions & 1 deletion tools/fastllm_pytools/llm.py
@@ -96,6 +96,12 @@
fastllm_lib.embedding_sentence.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_bool, ctypes.POINTER(ctypes.c_int)]
fastllm_lib.embedding_sentence.restype = ctypes.POINTER(ctypes.c_float)

fastllm_lib.embedding_tokens.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_int), ctypes.c_bool, ctypes.POINTER(ctypes.c_int)]
fastllm_lib.embedding_tokens.restype = ctypes.POINTER(ctypes.c_float)

fastllm_lib.reranker_compute_score.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int)]
fastllm_lib.reranker_compute_score.restype = ctypes.POINTER(ctypes.c_float)

def softmax(a):
max_value = a[0]
for i in a:
@@ -1086,12 +1092,30 @@ def get_max_input_len(self):

def embedding_sentence(self, input: str, normalize = True):
embedding_len = ctypes.c_int(0)
embedding_c_float = fastllm_lib.embedding_sentence(self.model, input.encode(), normalize, embedding_len)
if (self.hf_tokenizer != None):
input_ids = self.hf_tokenizer(input, padding = True, truncation = True)['input_ids']
embedding_c_float = fastllm_lib.embedding_tokens(self.model, len(input_ids), (ctypes.c_int * len(input_ids))(*input_ids), normalize, embedding_len)
else:
embedding_c_float = fastllm_lib.embedding_sentence(self.model, input.encode(), normalize, embedding_len)
embedding = []
for i in range(embedding_len.value):
embedding.append(embedding_c_float[i])
#print("{:.7f}".format(embedding[i]), end=" ")
return embedding

def reranker_compute_score(self, pairs: List):
batch = len(pairs)
seq_lens = []
tokens = []
for i in range(batch):
input_ids = self.hf_tokenizer(pairs[i : i + 1], padding = True, truncation = True)['input_ids'][0]
seq_lens.append(len(input_ids))
tokens += input_ids
ret_c = fastllm_lib.reranker_compute_score(self.model, batch, (ctypes.c_int * len(seq_lens))(*seq_lens), (ctypes.c_int * len(tokens))(*tokens))
ret = []
for i in range(batch):
ret.append(ret_c[i])
return ret

def GraphNode(name: str,
type: str = "data",
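With these bindings in place, embedding_sentence prefers the attached Hugging Face tokenizer when one exists and falls back to the built-in tokenizer otherwise, while reranker_compute_score scores a batch of query-passage pairs in one call (it requires the Hugging Face tokenizer). A hedged usage sketch; the model files are hypothetical, and loading follows the existing fastllm_pytools conventions:

from fastllm_pytools import llm

embed_model = llm.model("bge-embedding.flm")      # hypothetical embedding model file
vec = embed_model.embedding_sentence("hello world", normalize = True)
print(len(vec))                                   # embedding dimension

rerank_model = llm.model("bge-reranker.flm")      # hypothetical reranker model file
scores = rerank_model.reranker_compute_score([
    ["what is fastllm?", "fastllm is a pure C++ LLM inference library"],
    ["what is fastllm?", "bananas are rich in potassium"],
])
print(scores)                                     # one relevance score per pair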
31 changes: 31 additions & 0 deletions tools/src/pytools.cpp
@@ -435,4 +435,35 @@ extern "C" {
*embeddingLen = result.size();
return fvalue;
}

DLL_EXPORT float* embedding_tokens(int modelId, int inputLen, int *input, bool normalize, int *embeddingLen) {
fastllm::BertModel *model = (fastllm::BertModel*)models.GetModel(modelId);
std::vector <int> tokens;
for (int i = 0; i < inputLen; i++) {
tokens.push_back(input[i]);
}
std::vector <float> result = model->EmbeddingSentence(tokens, normalize);
float *fvalue = new float[result.size()];
memcpy(fvalue, result.data(), result.size() * sizeof(float));
*embeddingLen = result.size();
return fvalue;
}

DLL_EXPORT float* reranker_compute_score(int modelId, int batch, int *seqLens, int *tokens) {
fastllm::XlmRobertaModel *model = (fastllm::XlmRobertaModel*)models.GetModel(modelId);
std::vector <std::vector <int> > inputIds;
inputIds.resize(batch);
int pos = 0;
for (int i = 0; i < batch; i++) {
for (int j = 0; j < seqLens[i]; j++) {
inputIds[i].push_back(tokens[pos++]);
}
}
auto ret = model->ComputeScore(inputIds);
float *fvalue = new float[batch];
for (int i = 0; i < batch; i++) {
fvalue[i] = ret[i];
}
return fvalue;
}
};
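On the C side, reranker_compute_score receives the batch as one flat token array plus a per-sequence length array and rebuilds the nested inputIds before calling ComputeScore. A small ctypes sketch of that flat layout, mirroring what llm.py does above (the token ids are made up):

import ctypes

pair_tokens = [[101, 2054, 102], [101, 2003, 2009, 102]]   # hypothetical token ids for two pairs
seq_lens = [len(seq) for seq in pair_tokens]
flat = [tok for seq in pair_tokens for tok in seq]         # concatenated in batch order
c_lens = (ctypes.c_int * len(seq_lens))(*seq_lens)
c_toks = (ctypes.c_int * len(flat))(*flat)
# scores = fastllm_lib.reranker_compute_score(model_id, len(seq_lens), c_lens, c_toks)
# would return a float* holding one score per sequence, in the same order.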
