Merge pull request #458 from levinxo/master
Add an add_special_tokens option, defaulting to true, to support chatglm
ztxz16 authored May 22, 2024
2 parents 6656e95 + 9c88e68 commit e6fd13c
Showing 3 changed files with 37 additions and 20 deletions.
1 change: 1 addition & 0 deletions include/fastllm.h
@@ -46,6 +46,7 @@ namespace fastllm {
         float temperature = 1.0; // temperature, usually between 0.1 and 1.0; larger values make the results more diverse
         bool output_logits = false; // whether to return logits
         bool enable_hash_id = false; // attach a hash id to the session
+        bool add_special_tokens = true; // add special tokens to the prompt (takes effect for chatglm models)
         std::multiset <int> stop_token_ids;

         bool IsSimpleGreedy() const {
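Since the new field defaults to true, existing callers keep the old behavior; a caller that manages its own special tokens can opt out. A minimal sketch, assuming the struct shown above is fastllm::GenerationConfig (the struct name and the include path are assumptions for illustration, not shown in this diff):

    #include "fastllm.h"

    int main() {
        fastllm::GenerationConfig config;  // struct name assumed
        config.temperature = 0.7f;
        // Opt out of automatic special-token insertion: the caller's
        // prompt ids are assumed to already contain [gMASK]/<bos>.
        config.add_special_tokens = false;
        return 0;
    }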
8 changes: 6 additions & 2 deletions src/models/basellm.cpp
@@ -114,7 +114,9 @@ namespace fastllm {
         std::vector<float> results;
         LastTokensManager tokens(1, generationConfig.last_n);
         int promptLen = lastPromptTokens + inputTokens[0].size(), index = 0;
-        FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}}, inputIds, attentionMask, positionIds);
+        int add_special_tokens = generationConfig.add_special_tokens? 1: 0;
+        FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}},
+                      inputIds, attentionMask, positionIds);
         while (true) {
             auto st = std::chrono::system_clock::now();
             int ret = Forward(inputIds, attentionMask, positionIds, pastKeyValues, generationConfig, tokens);
@@ -146,7 +148,8 @@ namespace fastllm {
             results.clear();

             inputTokens[0] = std::vector<float> {(float)ret};
-            FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}}, inputIds, attentionMask, positionIds);
+            FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}},
+                          inputIds, attentionMask, positionIds);
             if (index == generationConfig.output_token_limit) {
                 break;
             }
@@ -223,6 +226,7 @@ namespace fastllm {
             }
             params[0]["index"] = 0;
             int index = 0;
+            params[0]["add_special_tokens"] = generationConfig.add_special_tokens? 1: 0;

             LastTokensManager tokensManager (batch, generationConfig.last_n);
             std::vector <bool> isEnding = std::vector <bool> (batch, false);
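FillLLMInputs receives its side parameters through an int-valued map (the diff shows params[0]["add_special_tokens"] assigned 0/1 and read back via find()->second), so the bool rides along as an int and is decoded on the model side. A self-contained sketch of that round trip; the literal values are illustrative:

    #include <map>
    #include <string>

    int main() {
        bool flag = true; // stands in for generationConfig.add_special_tokens
        // Producer side (basellm.cpp): encode the bool into the int-valued map.
        std::map <std::string, int> params = {
            {"promptLen", 128},
            {"index", 0},
            {"add_special_tokens", flag ? 1 : 0},
        };
        // Consumer side (chatglm.cpp): decode it back to a bool.
        bool add_special_tokens = params.find("add_special_tokens")->second != 0;
        return add_special_tokens ? 0 : 1;
    }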
48 changes: 30 additions & 18 deletions src/models/chatglm.cpp
@@ -751,16 +751,19 @@ namespace fastllm {

         int index = params.find("index")->second;
         int promptLen = params.find("promptLen")->second;
+        bool add_special_tokens = params.find("add_special_tokens")->second == 0? false: true;

         if (index == 0) {
-            for (auto &ids: inputTokens) {
-                if (GetVersion() == 1) {
-                    ids.push_back(gmask_token_id);
-                    ids.push_back(bos_token_id);
-                } else if (GetVersion() == 2) {
-                    if (ids.size() < 2 || ids[0] != this->gmask_token_id || ids[1] != this->bos_token_id) {
-                        ids.insert(ids.begin(), this->bos_token_id);
-                        ids.insert(ids.begin(), this->gmask_token_id);
+            if (add_special_tokens) {
+                for (auto &ids: inputTokens) {
+                    if (GetVersion() == 1) {
+                        ids.push_back(gmask_token_id);
+                        ids.push_back(bos_token_id);
+                    } else if (GetVersion() == 2) {
+                        if (ids.size() < 2 || ids[0] != this->gmask_token_id || ids[1] != this->bos_token_id) {
+                            ids.insert(ids.begin(), this->bos_token_id);
+                            ids.insert(ids.begin(), this->gmask_token_id);
+                        }
                     }
                 }
             }
@@ -809,12 +812,17 @@

         int batch = inputTokens.size();
         int index = params[0].find("index")->second;
+        bool add_special_tokens = params[0].find("add_special_tokens")->second == 0? false: true;
+        int special_tokens_offset = 0;
+        if (add_special_tokens) {
+            special_tokens_offset = 2;
+        }
         if (index == 0) {
             std::vector<int> seqLens;
             seqLens.resize(batch);
             int maxLen = 0;
             for (int i = 0; i < batch; i++) {
-                maxLen = std::max(maxLen, (int) inputTokens[i].size() + 2);
+                maxLen = std::max(maxLen, (int) inputTokens[i].size() + special_tokens_offset);
                 seqLens[i] = (int) inputTokens[i].size();
             }

@@ -824,13 +832,15 @@
             for (int i = 0; i < batch; i++) {
                 if (GetVersion() == 1) {
                     auto &tokens = inputTokens[i];
-                    int len = tokens.size(), base = maxLen - 2 - len;
+                    int len = tokens.size(), base = maxLen - special_tokens_offset - len;
                     for (int j = 0; j < len; j++) {
                         ids[i * maxLen + base + j] = tokens[j];
                     }
-                    ids[i * maxLen + base + len] = gmask_token_id;
-                    ids[i * maxLen + base + len + 1] = bos_token_id;
-                    len += 2;
+                    if (add_special_tokens) {
+                        ids[i * maxLen + base + len] = gmask_token_id;
+                        ids[i * maxLen + base + len + 1] = bos_token_id;
+                    }
+                    len += special_tokens_offset;
                     for (int j = 0; j < len - 1; j++) {
                         vpids[i * 2 * maxLen + base + j] = j;
                     }
@@ -847,13 +857,15 @@
                     }
                 } else {
                     auto &tokens = inputTokens[i];
-                    int len = tokens.size(), base = maxLen - 2 - len;
-                    ids[i * maxLen + base] = gmask_token_id;
-                    ids[i * maxLen + base + 1] = bos_token_id;
+                    int len = tokens.size(), base = maxLen - special_tokens_offset - len;
+                    if (add_special_tokens) {
+                        ids[i * maxLen + base] = gmask_token_id;
+                        ids[i * maxLen + base + 1] = bos_token_id;
+                    }
                     for (int j = 0; j < len; j++) {
-                        ids[i * maxLen + base + 2 + j] = tokens[j];
+                        ids[i * maxLen + base + special_tokens_offset + j] = tokens[j];
                     }
-                    len += 2;
+                    len += special_tokens_offset;
                     for (int j = 0; j < len; j++) {
                         vpids[i * 2 * maxLen + base + j] = j;
                     }
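To make the index arithmetic above concrete: base = maxLen - special_tokens_offset - len right-aligns each sequence in its maxLen-wide row, and the optional [gMASK]/<bos> pair lands immediately before the prompt tokens. A standalone sketch of the GetVersion() == 2 layout; the token ids and pad value are assumptions for illustration, not taken from the repo:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        const int gmask_token_id = 64790, bos_token_id = 64792, pad = 0; // assumed ids
        std::vector<std::vector<int>> inputTokens = {{11, 12, 13}, {21, 22}};
        bool add_special_tokens = true;
        int special_tokens_offset = add_special_tokens ? 2 : 0;

        int batch = inputTokens.size(), maxLen = 0;
        for (auto &t : inputTokens)
            maxLen = std::max(maxLen, (int) t.size() + special_tokens_offset);

        std::vector<int> ids(batch * maxLen, pad);
        for (int i = 0; i < batch; i++) {
            int len = inputTokens[i].size(), base = maxLen - special_tokens_offset - len;
            if (add_special_tokens) {
                ids[i * maxLen + base] = gmask_token_id;    // [gMASK]
                ids[i * maxLen + base + 1] = bos_token_id;  // <bos>
            }
            for (int j = 0; j < len; j++)
                ids[i * maxLen + base + special_tokens_offset + j] = inputTokens[i][j];
        }

        // Prints:
        // 64790 64792 11 12 13
        // 0 64790 64792 21 22
        for (int i = 0; i < batch; i++) {
            for (int j = 0; j < maxLen; j++) printf("%d ", ids[i * maxLen + j]);
            printf("\n");
        }
        return 0;
    }

Keeping special tokens and padding inside one right-aligned row is what lets every sequence in the batch end at the same column, so position ids and the attention mask can be filled with the same loop regardless of prompt length.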
