diff --git a/include/fastllm.h b/include/fastllm.h index 4305a51a..93e47eb1 100644 --- a/include/fastllm.h +++ b/include/fastllm.h @@ -332,6 +332,23 @@ namespace fastllm { void SetKVCache(); }; + struct PartitionLinkNode { + std::pair *cur = nullptr; + PartitionLinkNode *next = nullptr; + PartitionLinkNode *prev = nullptr; + int id = -1; + + PartitionLinkNode *Skip(int t) { + PartitionLinkNode *ret = this; + while (t--) { + if (ret != nullptr) { + ret = ret->next; + } + } + return ret; + } + }; + struct Tokenizer { enum TokenizerType { BPE = 0, @@ -415,6 +432,8 @@ namespace fastllm { void TryMergePairs(std::vector &symbols, int l, int r, std::priority_queue &q); // 插入备选symbol + int GetRank(std::vector &symbols, PartitionLinkNode *cur, int skip); + int GetRank(std::vector &symbols, std::vector> &partitions, int idx, int skip); void Insert(const std::string &s, int tokenId, float score = 1.0f); // 插入一个token diff --git a/src/fastllm.cpp b/src/fastllm.cpp index 8fe74484..c2fbbd71 100644 --- a/src/fastllm.cpp +++ b/src/fastllm.cpp @@ -1310,6 +1310,19 @@ namespace fastllm { q.push(SymbolPairs(now->score, l, r, symbols[l].len + symbols[r].len)); } + int Tokenizer::GetRank(std::vector &symbols, PartitionLinkNode *cur, int skip) { + auto nxt = cur->Skip(skip + 2); + if (nxt == nullptr) { + return std::numeric_limits::max(); + } + auto s = symbols[0].s + symbols[0].pos; + std::string key(s + cur->cur->first, s + nxt->cur->first); + if (stringToTokenDict.find(key) != stringToTokenDict.end()) { + return stringToTokenDict[key]; + } + return std::numeric_limits::max(); + } + int Tokenizer::GetRank(std::vector &symbols, std::vector> &partitions, int idx, int skip) { if (idx + skip + 2 >= partitions.size()) { return std::numeric_limits::max(); @@ -1614,42 +1627,63 @@ namespace fastllm { std::vector symbols; std::vector v; + for (int i = 0; i <= ori.size(); i++) { if (i == sep.back().first) { if (!symbols.empty()) { symbols.back().next = -1; std::string cur = ori.substr(i - symbols.size(), symbols.size()); std::vector> partitions(symbols.size() + 1); + std::vector nodes(symbols.size() + 1); for (int j = 0; j <= (int) symbols.size(); j++) { partitions[j] = std::make_pair(j, std::numeric_limits::max()); } + for (int j = 0; j <= (int) symbols.size(); j++) { + nodes[j].cur = &partitions[j]; + if (j > 0) { + nodes[j].prev = &nodes[j - 1]; + } + if (j + 1 < nodes.size()) { + nodes[j].next = &nodes[j + 1]; + } + nodes[j].id = j; + } for (int j = 0; j < partitions.size() - 2; j++) { partitions[j].second = GetRank(symbols, partitions, j, 0); } - while (partitions.size() > 1) { - int min_rank = std::numeric_limits::max(); - int min_rank_idx = 0; - for (int j = 0; j < partitions.size() - 1; ++j) { - if (partitions[j].second < min_rank) { - min_rank = partitions[j].second; - min_rank_idx = j; - } - } + std::set > pq; + for (int j = 0; j < nodes.size(); j++) { + pq.insert(std::make_pair(nodes[j].cur->second, j)); + } + int del = 0; + while (partitions.size() - del > 1) { + int min_rank = pq.begin()->first; + auto sel = &nodes[pq.begin()->second]; + if (min_rank != std::numeric_limits::max()) { - partitions[min_rank_idx].second = GetRank(symbols, partitions, min_rank_idx, 1); - if (min_rank_idx > 0) { - partitions[min_rank_idx - 1].second = GetRank(symbols, partitions, min_rank_idx - 1, 1); + pq.erase(std::make_pair(sel->cur->second, sel->id)); + sel->cur->second = GetRank(symbols, sel, 1); + pq.insert(std::make_pair(sel->cur->second, sel->id)); + if (sel->prev != nullptr) { + pq.erase(std::make_pair(sel->prev->cur->second, sel->prev->id)); + sel->prev->cur->second = GetRank(symbols, sel->prev, 1); + pq.insert(std::make_pair(sel->prev->cur->second, sel->prev->id)); } - partitions.erase(partitions.begin() + min_rank_idx + 1); + pq.erase(std::make_pair(sel->next->cur->second, sel->next->id)); + sel->next = sel->next->next; + sel->next->prev = sel; + del++; } else { break; } } - symbols.clear(); - for (int j = 0; j < partitions.size() - 1; j++) { - std::string key = cur.substr(partitions[j].first, partitions[j + 1].first - partitions[j].first); + auto it = &nodes[0]; + while (it != nullptr && it->next != nullptr) { + std::string key = cur.substr(it->cur->first, it->next->cur->first - it->cur->first); v.push_back((float) stringToTokenDict[key]); + it = it->next; } + symbols.clear(); } std::string special = ori.substr(sep.back().first, sep.back().second);