Skip to content

Commit

Permalink
Update rag_analyzer.cpp (#2273)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Improve the performance of `RAGAnalyzer::GetBestTokensTopN` and reduce its
memory cost.

Issue link: #2159

### Type of change

- [x] Refactoring
- [x] Performance Improvement
  • Loading branch information
yangzq50 authored Nov 20, 2024
1 parent 6f695bc commit 7b3ba74
Showing 1 changed file with 73 additions and 80 deletions.
153 changes: 73 additions & 80 deletions src/common/analyzer/rag_analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -821,78 +821,64 @@ int RAGAnalyzer::DFS(const String &chars,
return DFS(chars, s + 1, pre_tokens, token_list, best_tokens, max_score, memo_all);
}

// One link in a backward-linked chain of tokens.
//
// Each node stores a single token view plus a pointer to the node that
// precedes it. A default-constructed node has a null `prev` and an empty
// token. The member order (prev, token) matters: nodes are built with
// positional aggregate initialisation.
struct TokensList {
    // Preceding node in the chain; null when there is none.
    const TokensList *prev{nullptr};
    // Non-owning view of this node's token text.
    std::string_view token{};
};

struct BestTokenCandidate {
#ifdef DIVIDE_F_BY_N
u32 token_num{};
i64 score_sum{};
#else
static constexpr i64 B = 30;
TokensList tl{};
// N: token num
// L: num of tokens with length >= 2
// F: sum of freq
Pair<u32, u32> N_L{};
u32 N{};
u32 L{};
i64 F{};
auto k() const {
#ifdef DIVIDE_F_BY_N
return N;
#else
return std::make_pair(N, L);
#endif
Vector<std::string_view> tokens{};
};

template <class... Fs>
struct Overload : Fs... {
using Fs::operator()...;
}
auto v() const { return F; }
auto score() const {
#ifdef DIVIDE_F_BY_N
return static_cast<double>(B + L + F) / N;
#else
return F + (static_cast<double>(B + L) / N);
#endif
}
BestTokenCandidate update(const std::string_view new_token_sv, const i32 key_f, const u32 add_l) const {
return {{&tl, new_token_sv}, N + 1, L + add_l, F + key_f};
}
};

// explicit deduction guide
template <class... Fs>
Overload(Fs...) -> Overload<Fs...>;

struct GrowingBestTokenCandidatesTopN {
const i32 top_n{};
i32 top_n{};
Vector<BestTokenCandidate> candidates{};

explicit GrowingBestTokenCandidatesTopN(const i32 top_n) : top_n(top_n) {}

#ifdef DIVIDE_F_BY_N
void AddBestTokenCandidateTopN(const u32 tn, const i64 ss, const Vector<std::string_view> &tks_old_first, const std::string_view wait_append) {
const auto e_r_comp = Overload{[](const BestTokenCandidate &a, const u32 x) { return a.token_num < x; },
[](const u32 x, const BestTokenCandidate &a) { return x < a.token_num; }};
const auto min_comp = [](const BestTokenCandidate &a, const BestTokenCandidate &b) { return a.score_sum < b.score_sum; };
const auto [it_b, it_e] = std::equal_range(candidates.begin(), candidates.end(), tn, e_r_comp);
#else
void AddBestTokenCandidateTopN(const Pair<u32, u32> n_l,
const i64 new_f,
const Vector<std::string_view> &tks_old_first,
const std::string_view wait_append) {
const auto e_r_comp = Overload{[](const BestTokenCandidate &a, const Pair<u32, u32> x) { return a.N_L < x; },
[](const Pair<u32, u32> x, const BestTokenCandidate &a) { return x < a.N_L; }};
const auto min_comp = [](const BestTokenCandidate &a, const BestTokenCandidate &b) { return a.F < b.F; };
const auto [it_b, it_e] = std::equal_range(candidates.begin(), candidates.end(), n_l, e_r_comp);
#endif
void AddBestTokenCandidateTopN(const BestTokenCandidate &add_candidate) {
const auto [it_b, it_e] =
std::equal_range(candidates.begin(), candidates.end(), add_candidate, [](const auto &a, const auto &b) { return a.k() < b.k(); });
auto target_it = it_b;
bool do_replace = false;
if (const auto match_cnt = std::distance(it_b, it_e); match_cnt >= top_n) {
assert(match_cnt == top_n);
const auto it = std::min_element(it_b, it_e, min_comp);
#ifdef DIVIDE_F_BY_N
if (it->score_sum >= ss) {
#else
if (it->F >= new_f) {
#endif
const auto it = std::min_element(it_b, it_e, [](const auto &a, const auto &b) { return a.v() < b.v(); });
if (it->v() >= add_candidate.v()) {
return;
}
target_it = it;
do_replace = true;
}
#ifdef DIVIDE_F_BY_N
BestTokenCandidate candidate = {tn, ss};
#else
BestTokenCandidate candidate = {n_l, new_f};
#endif
candidate.tokens.reserve(tks_old_first.size() + 1);
candidate.tokens.insert(candidate.tokens.end(), tks_old_first.begin(), tks_old_first.end());
candidate.tokens.push_back(wait_append);
if (do_replace) {
*target_it = std::move(candidate);
*target_it = add_candidate;
} else {
candidates.insert(target_it, std::move(candidate));
candidates.insert(target_it, add_candidate);
}
}
};
Expand Down Expand Up @@ -925,21 +911,9 @@ Vector<Pair<Vector<std::string_view>, double>> RAGAnalyzer::GetBestTokensTopN(co
lookup_left_chars -= next_one_utf8.size();
}
auto dp_f = [&dp_vec, i, j, original_sv = std::string_view{current_utf8_ptr, growing_key.size()}](const i32 key_f, const u32 add_l) {
#ifdef DIVIDE_F_BY_N
const i32 key_score = key_f + add_l;
for (auto &target_dp = dp_vec[i + j]; const auto &[tn, ss, v] : dp_vec[i].candidates) {
target_dp.AddBestTokenCandidateTopN(tn + 1, ss + key_score, v, original_sv);
for (auto &target_dp = dp_vec[i + j]; const auto &c : dp_vec[i].candidates) {
target_dp.AddBestTokenCandidateTopN(c.update(original_sv, key_f, add_l));
}
#else
auto get_add_n_l = [add_l](Pair<u32, u32> old_n_l) {
++old_n_l.first;
old_n_l.second += add_l;
return old_n_l;
};
for (auto &target_dp = dp_vec[i + j]; const auto &[old_n_l, old_f, old_v] : dp_vec[i].candidates) {
target_dp.AddBestTokenCandidateTopN(get_add_n_l(old_n_l), old_f + key_f, old_v, original_sv);
}
#endif
};
if (const auto traverse_result = trie_->Traverse(growing_key.data(), reuse_node_pos, reuse_key_pos, growing_key.size());
traverse_result >= 0) {
Expand All @@ -964,27 +938,46 @@ Vector<Pair<Vector<std::string_view>, double>> RAGAnalyzer::GetBestTokensTopN(co
current_utf8_ptr += forward_cnt;
current_left_chars -= forward_cnt;
}
Vector<Pair<Vector<std::string_view>, double>> result;
result.reserve(n);
constexpr i64 B = 30;
#ifdef DIVIDE_F_BY_N
for (auto &[token_num, score_sum, tokens] : dp_vec.back().candidates) {
auto new_pair = std::make_pair(std::move(tokens), (static_cast<double>(B + score_sum) / token_num));
#else
for (auto &[N_L, F, tokens] : dp_vec.back().candidates) {
auto new_pair = std::make_pair(std::move(tokens), (F + (static_cast<double>(B + N_L.second) / N_L.first)));
#endif
if (result.size() < n) {
result.push_back(std::move(new_pair));
Vector<Pair<const TokensList *, double>> mid_result;
mid_result.reserve(n);
for (const auto &c : dp_vec.back().candidates) {
const auto new_pair = std::make_pair(&(c.tl), c.score());
if (mid_result.size() < n) {
mid_result.push_back(new_pair);
} else {
assert(result.size() == n);
if (new_pair.second > result.back().second) {
result.pop_back();
const auto insert_pos =
std::lower_bound(result.begin(), result.end(), new_pair, [](const auto &a, const auto &b) { return a.second > b.second; });
result.insert(insert_pos, std::move(new_pair));
assert(mid_result.size() == n);
if (new_pair.second > mid_result.back().second) {
mid_result.pop_back();
const auto insert_pos = std::lower_bound(mid_result.begin(), mid_result.end(), new_pair, [](const auto &a, const auto &b) {
return a.second > b.second;
});
mid_result.insert(insert_pos, new_pair);
}
}
}
// Rebuilds the ordered token vector described by a TokensList chain.
//
// A chain is walked backwards through `prev`. The node whose `prev` is null
// terminates the walk and contributes no token; every other node contributes
// its `token`. The returned vector lists the tokens oldest-first, i.e. the
// token of the node passed in comes last.
//
// The walk is iterative so that stack usage does not grow with the number of
// tokens in the chain.
class HelperFunc {
public:
    // Returns the tokens of the chain ending at `tl`, in forward order.
    // `tl` must be non-null: it is dereferenced unconditionally.
    Vector<std::string_view> GetTokens(const TokensList *tl) {
        // First pass: count the nodes that carry a token so the output can be
        // sized exactly once.
        u32 cnt = 0;
        for (const TokensList *node = tl; node->prev; node = node->prev) {
            ++cnt;
        }
        // Second pass: the chain is visited newest-first, so fill the output
        // from the back to obtain oldest-first order.
        Vector<std::string_view> result(cnt);
        u32 pos = cnt;
        for (const TokensList *node = tl; node->prev; node = node->prev) {
            result[--pos] = node->token;
        }
        return result;
    }
};
Vector<Pair<Vector<std::string_view>, double>> result;
result.reserve(mid_result.size());
for (const auto [tl, score] : mid_result) {
result.emplace_back(HelperFunc{}.GetTokens(tl), score);
}
return result;
}
Expand Down

0 comments on commit 7b3ba74

Please sign in to comment.