diff --git a/include/fastllm.h b/include/fastllm.h
index 16820335..f730171a 100644
--- a/include/fastllm.h
+++ b/include/fastllm.h
@@ -46,6 +46,7 @@ namespace fastllm {
         float temperature = 1.0; // temperature, usually between 0.1 and 1.0; larger values produce more diverse results
         bool output_logits = false; // whether to output logits
         bool enable_hash_id = false; // attach a hash id to the session
+        bool add_special_tokens = true; // add special tokens to the prompt (only takes effect for chatglm models)
         std::multiset <int> stop_token_ids;
 
         bool IsSimpleGreedy() const {
diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp
index 3cc7e5b9..57912b8e 100644
--- a/src/models/basellm.cpp
+++ b/src/models/basellm.cpp
@@ -114,7 +114,9 @@ namespace fastllm {
         std::vector <float> results;
         LastTokensManager tokens(1, generationConfig.last_n);
         int promptLen = lastPromptTokens + inputTokens[0].size(), index = 0;
-        FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}}, inputIds, attentionMask, positionIds);
+        int add_special_tokens = generationConfig.add_special_tokens? 1: 0;
+        FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}},
+                      inputIds, attentionMask, positionIds);
         while (true) {
             auto st = std::chrono::system_clock::now();
             int ret = Forward(inputIds, attentionMask, positionIds, pastKeyValues, generationConfig, tokens);
@@ -146,7 +148,8 @@ namespace fastllm {
             results.clear();
 
             inputTokens[0] = std::vector <float> {(float)ret};
-            FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}}, inputIds, attentionMask, positionIds);
+            FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}},
+                          inputIds, attentionMask, positionIds);
             if (index == generationConfig.output_token_limit) {
                 break;
             }
@@ -223,6 +226,7 @@ namespace fastllm {
         }
         params[0]["index"] = 0;
         int index = 0;
+        params[0]["add_special_tokens"] = generationConfig.add_special_tokens? 1: 0;
 
         LastTokensManager tokensManager (batch, generationConfig.last_n);
         std::vector <bool> isEnding = std::vector <bool> (batch, false);
diff --git a/src/models/chatglm.cpp b/src/models/chatglm.cpp
index 5e26644d..2e31c55f 100644
--- a/src/models/chatglm.cpp
+++ b/src/models/chatglm.cpp
@@ -751,16 +751,19 @@
         int index = params.find("index")->second;
         int promptLen = params.find("promptLen")->second;
+        bool add_special_tokens = params.find("add_special_tokens")->second == 0? false: true;
 
         if (index == 0) {
-            for (auto &ids: inputTokens) {
-                if (GetVersion() == 1) {
-                    ids.push_back(gmask_token_id);
-                    ids.push_back(bos_token_id);
-                } else if (GetVersion() == 2) {
-                    if (ids.size() < 2 || ids[0] != this->gmask_token_id || ids[1] != this->bos_token_id) {
-                        ids.insert(ids.begin(), this->bos_token_id);
-                        ids.insert(ids.begin(), this->gmask_token_id);
+            if (add_special_tokens) {
+                for (auto &ids: inputTokens) {
+                    if (GetVersion() == 1) {
+                        ids.push_back(gmask_token_id);
+                        ids.push_back(bos_token_id);
+                    } else if (GetVersion() == 2) {
+                        if (ids.size() < 2 || ids[0] != this->gmask_token_id || ids[1] != this->bos_token_id) {
+                            ids.insert(ids.begin(), this->bos_token_id);
+                            ids.insert(ids.begin(), this->gmask_token_id);
+                        }
                     }
                 }
             }
 
@@ -809,12 +812,17 @@
         int batch = inputTokens.size();
         int index = params[0].find("index")->second;
+        bool add_special_tokens = params[0].find("add_special_tokens")->second == 0? false: true;
+        int special_tokens_offset = 0;
+        if (add_special_tokens) {
+            special_tokens_offset = 2;
+        }
 
         if (index == 0) {
             std::vector <int> seqLens;
             seqLens.resize(batch);
             int maxLen = 0;
             for (int i = 0; i < batch; i++) {
-                maxLen = std::max(maxLen, (int) inputTokens[i].size() + 2);
+                maxLen = std::max(maxLen, (int) inputTokens[i].size() + special_tokens_offset);
                 seqLens[i] = (int) inputTokens[i].size();
             }
 
@@ -824,13 +832,15 @@
             for (int i = 0; i < batch; i++) {
                 if (GetVersion() == 1) {
                     auto &tokens = inputTokens[i];
-                    int len = tokens.size(), base = maxLen - 2 - len;
+                    int len = tokens.size(), base = maxLen - special_tokens_offset - len;
                     for (int j = 0; j < len; j++) {
                         ids[i * maxLen + base + j] = tokens[j];
                     }
-                    ids[i * maxLen + base + len] = gmask_token_id;
-                    ids[i * maxLen + base + len + 1] = bos_token_id;
-                    len += 2;
+                    if (add_special_tokens) {
+                        ids[i * maxLen + base + len] = gmask_token_id;
+                        ids[i * maxLen + base + len + 1] = bos_token_id;
+                    }
+                    len += special_tokens_offset;
                     for (int j = 0; j < len - 1; j++) {
                         vpids[i * 2 * maxLen + base + j] = j;
                     }
@@ -847,13 +857,15 @@
                     }
                 } else {
                     auto &tokens = inputTokens[i];
-                    int len = tokens.size(), base = maxLen - 2 - len;
-                    ids[i * maxLen + base] = gmask_token_id;
-                    ids[i * maxLen + base + 1] = bos_token_id;
+                    int len = tokens.size(), base = maxLen - special_tokens_offset - len;
+                    if (add_special_tokens) {
+                        ids[i * maxLen + base] = gmask_token_id;
+                        ids[i * maxLen + base + 1] = bos_token_id;
+                    }
                     for (int j = 0; j < len; j++) {
-                        ids[i * maxLen + base + 2 + j] = tokens[j];
+                        ids[i * maxLen + base + special_tokens_offset + j] = tokens[j];
                    }
-                    len += 2;
+                    len += special_tokens_offset;
                     for (int j = 0; j < len; j++) {
                         vpids[i * 2 * maxLen + base + j] = j;
                     }
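
For callers, the new flag rides along on GenerationConfig. Below is a minimal usage sketch, assuming the fastllm::CreateLLMModelFromFile factory from model.h and the streaming Response overload that accepts a GenerationConfig; the model path and prompt are illustrative, not part of this patch.

#include <cstdio>
#include "model.h"

int main() {
    // Load a converted ChatGLM model (path is illustrative).
    auto model = fastllm::CreateLLMModelFromFile("chatglm2-6b-int4.flm");

    fastllm::GenerationConfig config;
    // The prompt is assumed to already carry the [gMASK]/sop ids (or a raw
    // continuation is wanted), so the automatic insertion in
    // ChatGLMModel::FillLLMInputs is skipped.
    config.add_special_tokens = false;

    // With add_special_tokens left at its default (true), generation
    // behaves exactly as before this patch.
    model->Response("Hello", [](int index, const char *content) {
        printf("%s", content);  // streaming callback
    }, config);
    return 0;
}

One caveat visible in the diff: ChatGLMModel::FillLLMInputs and FillLLMInputsBatch dereference the result of params.find("add_special_tokens") without checking it, so any caller that invokes them directly, rather than through basellm::Response / ResponseBatch, must now supply the "add_special_tokens" key.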