Skip to content

Commit

Permalink
优化tiktoken分词
Browse files Browse the repository at this point in the history
  • Loading branch information
黄宇扬 committed Jun 6, 2024
1 parent 79d2016 commit e8103fd
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 16 deletions.
19 changes: 19 additions & 0 deletions include/fastllm.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,23 @@ namespace fastllm {
void SetKVCache();
};

struct PartitionLinkNode {
std::pair <int, int> *cur = nullptr;
PartitionLinkNode *next = nullptr;
PartitionLinkNode *prev = nullptr;
int id = -1;

PartitionLinkNode *Skip(int t) {
PartitionLinkNode *ret = this;
while (t--) {
if (ret != nullptr) {
ret = ret->next;
}
}
return ret;
}
};

struct Tokenizer {
enum TokenizerType {
BPE = 0,
Expand Down Expand Up @@ -415,6 +432,8 @@ namespace fastllm {

void TryMergePairs(std::vector<Symbol> &symbols, int l, int r, std::priority_queue <SymbolPairs> &q); // 插入备选symbol

int GetRank(std::vector <Symbol> &symbols, PartitionLinkNode *cur, int skip);

int GetRank(std::vector<Symbol> &symbols, std::vector<std::pair<int, int>> &partitions, int idx, int skip);

void Insert(const std::string &s, int tokenId, float score = 1.0f); // 插入一个token
Expand Down
66 changes: 50 additions & 16 deletions src/fastllm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1310,6 +1310,19 @@ namespace fastllm {
q.push(SymbolPairs(now->score, l, r, symbols[l].len + symbols[r].len));
}

int Tokenizer::GetRank(std::vector <Symbol> &symbols, PartitionLinkNode *cur, int skip) {
auto nxt = cur->Skip(skip + 2);
if (nxt == nullptr) {
return std::numeric_limits<int>::max();
}
auto s = symbols[0].s + symbols[0].pos;
std::string key(s + cur->cur->first, s + nxt->cur->first);
if (stringToTokenDict.find(key) != stringToTokenDict.end()) {
return stringToTokenDict[key];
}
return std::numeric_limits<int>::max();
}

int Tokenizer::GetRank(std::vector<Symbol> &symbols, std::vector<std::pair<int, int>> &partitions, int idx, int skip) {
if (idx + skip + 2 >= partitions.size()) {
return std::numeric_limits<int>::max();
Expand Down Expand Up @@ -1614,42 +1627,63 @@ namespace fastllm {

std::vector<Symbol> symbols;
std::vector<float> v;

for (int i = 0; i <= ori.size(); i++) {
if (i == sep.back().first) {
if (!symbols.empty()) {
symbols.back().next = -1;
std::string cur = ori.substr(i - symbols.size(), symbols.size());
std::vector<std::pair<int, int>> partitions(symbols.size() + 1);
std::vector <PartitionLinkNode> nodes(symbols.size() + 1);
for (int j = 0; j <= (int) symbols.size(); j++) {
partitions[j] = std::make_pair(j, std::numeric_limits<int>::max());
}
for (int j = 0; j <= (int) symbols.size(); j++) {
nodes[j].cur = &partitions[j];
if (j > 0) {
nodes[j].prev = &nodes[j - 1];
}
if (j + 1 < nodes.size()) {
nodes[j].next = &nodes[j + 1];
}
nodes[j].id = j;
}
for (int j = 0; j < partitions.size() - 2; j++) {
partitions[j].second = GetRank(symbols, partitions, j, 0);
}
while (partitions.size() > 1) {
int min_rank = std::numeric_limits<int>::max();
int min_rank_idx = 0;
for (int j = 0; j < partitions.size() - 1; ++j) {
if (partitions[j].second < min_rank) {
min_rank = partitions[j].second;
min_rank_idx = j;
}
}
std::set <std::pair <int, int> > pq;
for (int j = 0; j < nodes.size(); j++) {
pq.insert(std::make_pair(nodes[j].cur->second, j));
}
int del = 0;
while (partitions.size() - del > 1) {
int min_rank = pq.begin()->first;
auto sel = &nodes[pq.begin()->second];

if (min_rank != std::numeric_limits<int>::max()) {
partitions[min_rank_idx].second = GetRank(symbols, partitions, min_rank_idx, 1);
if (min_rank_idx > 0) {
partitions[min_rank_idx - 1].second = GetRank(symbols, partitions, min_rank_idx - 1, 1);
pq.erase(std::make_pair(sel->cur->second, sel->id));
sel->cur->second = GetRank(symbols, sel, 1);
pq.insert(std::make_pair(sel->cur->second, sel->id));
if (sel->prev != nullptr) {
pq.erase(std::make_pair(sel->prev->cur->second, sel->prev->id));
sel->prev->cur->second = GetRank(symbols, sel->prev, 1);
pq.insert(std::make_pair(sel->prev->cur->second, sel->prev->id));
}
partitions.erase(partitions.begin() + min_rank_idx + 1);
pq.erase(std::make_pair(sel->next->cur->second, sel->next->id));
sel->next = sel->next->next;
sel->next->prev = sel;
del++;
} else {
break;
}
}
symbols.clear();
for (int j = 0; j < partitions.size() - 1; j++) {
std::string key = cur.substr(partitions[j].first, partitions[j + 1].first - partitions[j].first);
auto it = &nodes[0];
while (it != nullptr && it->next != nullptr) {
std::string key = cur.substr(it->cur->first, it->next->cur->first - it->cur->first);
v.push_back((float) stringToTokenDict[key]);
it = it->next;
}
symbols.clear();
}

std::string special = ori.substr(sep.back().first, sep.back().second);
Expand Down

0 comments on commit e8103fd

Please sign in to comment.