Skip to content

Commit

Permalink
Merge pull request #332 from siemonchan/qwen
Browse files Browse the repository at this point in the history
Update tiktoken for QWen
  • Loading branch information
ztxz16 authored Sep 25, 2023
2 parents 8862c81 + c4dff3d commit 909b4d9
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 36 deletions.
2 changes: 2 additions & 0 deletions include/fastllm.h
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,8 @@ namespace fastllm {

void TryMergePairs(std::vector<Symbol> &symbols, int l, int r, std::priority_queue <SymbolPairs> &q); // 插入备选symbol

int GetRank(std::vector<Symbol> &symbols, std::vector<std::pair<int, int>> &partitions, int idx, int skip);

void Insert(const std::string &s, int tokenId, float score = 1.0f); // 插入一个token

Data Encode(const std::string &s); // 编码
Expand Down
74 changes: 38 additions & 36 deletions src/fastllm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,18 @@ namespace fastllm {
q.push(SymbolPairs(now->score, l, r, symbols[l].len + symbols[r].len));
}

int Tokenizer::GetRank(std::vector<Symbol> &symbols, std::vector<std::pair<int, int>> &partitions, int idx, int skip) {
if (idx + skip + 2 >= partitions.size()) {
return std::numeric_limits<int>::max();
}
auto s = symbols[0].s + symbols[0].pos;
std::string key(s + partitions[idx].first, s + partitions[idx + skip + 2].first);
if (stringToTokenDict.find(key) != stringToTokenDict.end()) {
return stringToTokenDict[key];
}
return std::numeric_limits<int>::max();
}

Data Tokenizer::Encode(const std::string &ori) {
if (this->type == TokenizerType::BPE) {
std::string blank = "";
Expand Down Expand Up @@ -940,48 +952,38 @@ namespace fastllm {
if (i == sep.back().first) {
if (!symbols.empty()) {
symbols.back().next = -1;
std::priority_queue<SymbolPairs> workQueue;
for (int i = 1; i < symbols.size(); i++) {
TryMergePairs(symbols, i - 1, i, workQueue);
std::string cur = ori.substr(i - symbols.size(), symbols.size());
std::vector<std::pair<int, int>> partitions(symbols.size() + 1);
for (int j = 0; j <= (int) symbols.size(); j++) {
partitions[j] = std::make_pair(j, std::numeric_limits<int>::max());
}

while (!workQueue.empty()) {
auto top = workQueue.top();
workQueue.pop();
if (symbols[top.l].len == 0 || symbols[top.r].len == 0 ||
symbols[top.l].len + symbols[top.r].len != top.size) {
continue;
}

for (int i = symbols[top.r].pos; i < symbols[top.r].pos + symbols[top.r].len; i++) {
symbols[top.l].node = symbols[top.l].node->next[symbols[top.r].s[i]];
}
symbols[top.l].len += symbols[top.r].len;
symbols[top.r].len = 0;
symbols[top.l].next = symbols[top.r].next;
if (symbols[top.r].next >= 0) {
symbols[symbols[top.r].next].prev = top.l;
}

TryMergePairs(symbols, symbols[top.l].prev, top.l, workQueue);
TryMergePairs(symbols, top.l, symbols[top.l].next, workQueue);
for (int j = 0; j < partitions.size() - 2; j++) {
partitions[j].second = GetRank(symbols, partitions, j, 0);
}

for (int i = 0; i < symbols.size(); i++) {
if (symbols[i].len > 0) {
v.push_back(symbols[i].node->tokenId);
} else if (symbols[i].node == nullptr) {
// 未识别的字符
uint8_t c = (uint8_t) (symbols[i].s[symbols[i].pos]);
std::string now = "<0x00>";
now[3] = (c / 16 > 9 ? ('A' + c / 16 - 10) : ('0' + c / 16));
now[4] = (c % 16 > 9 ? ('A' + c % 16 - 10) : ('0' + c % 16));
if (stringToTokenDict.find(now) != stringToTokenDict.end()) {
v.push_back(stringToTokenDict[now]);
while (partitions.size() > 1) {
int min_rank = std::numeric_limits<int>::max();
int min_rank_idx = 0;
for (int j = 0; j < partitions.size() - 1; ++j) {
if (partitions[j].second < min_rank) {
min_rank = partitions[j].second;
min_rank_idx = j;
}
}
if (min_rank != std::numeric_limits<int>::max()) {
partitions[min_rank_idx].second = GetRank(symbols, partitions, min_rank_idx, 1);
if (min_rank_idx > 0) {
partitions[min_rank_idx - 1].second = GetRank(symbols, partitions, min_rank_idx - 1, 1);
}
partitions.erase(partitions.begin() + min_rank_idx + 1);
} else {
break;
}
}
symbols.clear();
for (int j = 0; j < partitions.size() - 1; j++) {
std::string key = cur.substr(partitions[j].first, partitions[j + 1].first - partitions[j].first);
v.push_back((float) stringToTokenDict[key]);
}
}

std::string special = ori.substr(sep.back().first, sep.back().second);
Expand Down

0 comments on commit 909b4d9

Please sign in to comment.