From 1ddde0ac0687fffc040f77fe98567e4fa524a474 Mon Sep 17 00:00:00 2001 From: yangzq50 <58433399+yangzq50@users.noreply.github.com> Date: Fri, 15 Nov 2024 21:08:28 +0800 Subject: [PATCH] Update rag analyzer (#2249) ### What problem does this PR solve? Support a new score function Support get topn result by dp Issue link:#2159 ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring --- src/common/analyzer/rag_analyzer.cpp | 299 ++++++++++++++++++-------- src/common/analyzer/rag_analyzer.cppm | 2 +- 2 files changed, 207 insertions(+), 94 deletions(-) diff --git a/src/common/analyzer/rag_analyzer.cpp b/src/common/analyzer/rag_analyzer.cpp index c058c72b53..c76500b291 100644 --- a/src/common/analyzer/rag_analyzer.cpp +++ b/src/common/analyzer/rag_analyzer.cpp @@ -25,6 +25,7 @@ module; #include #include #include +#include #include "string_utils.h" @@ -666,6 +667,8 @@ String RAGAnalyzer::RKey(const std::string_view line) { return reversed; } +#define DIVIDE_F_BY_N 1 + Pair, double> RAGAnalyzer::Score(const Vector> &token_freqs) { constexpr i64 B = 30; i64 F = 0, L = 0; @@ -676,7 +679,11 @@ Pair, double> RAGAnalyzer::Score(const Vector> L += (UTF8Length(token) < 2) ? 0 : 1; tokens.push_back(token); } +#ifdef DIVIDE_F_BY_N const auto score = (B + L + F) / static_cast(tokens.size()); +#else + const auto score = F + (B + L) / static_cast(tokens.size()); +#endif return {std::move(tokens), score}; } @@ -815,43 +822,84 @@ int RAGAnalyzer::DFS(const String &chars, } struct BestTokenCandidate { +#ifdef DIVIDE_F_BY_N u32 token_num{}; i64 score_sum{}; +#else + // N: token num + // L: num of tokens with length >= 2 + // F: sum of freq + Pair N_L{}; + i64 F{}; +#endif Vector tokens{}; }; -struct GrowingBestTokenCandidates { +template +struct Overload : Fs... { + using Fs::operator()...; +}; + +// explicit deduction guide +template +Overload(Fs...) -> Overload; + +struct GrowingBestTokenCandidatesTopN { + const i32 top_n{}; Vector candidates{}; - void AddBestTokenCandidate(const u32 tn, const i64 ss, const Vector &tks_old_first, const std::string_view wait_append) { - const auto it = - std::lower_bound(candidates.begin(), - candidates.end(), - tn, - [](const BestTokenCandidate &a, const u32 x) { - return a.token_num < x; - }); - const bool it_tn_same = (it != candidates.end() && it->token_num == tn); - if (it_tn_same && it->score_sum >= ss) { - return; + explicit GrowingBestTokenCandidatesTopN(const i32 top_n) : top_n(top_n) {} + +#ifdef DIVIDE_F_BY_N + void AddBestTokenCandidateTopN(const u32 tn, const i64 ss, const Vector &tks_old_first, const std::string_view wait_append) { + const auto e_r_comp = Overload{[](const BestTokenCandidate &a, const u32 x) { return a.token_num < x; }, + [](const u32 x, const BestTokenCandidate &a) { return x < a.token_num; }}; + const auto min_comp = [](const BestTokenCandidate &a, const BestTokenCandidate &b) { return a.score_sum < b.score_sum; }; + const auto [it_b, it_e] = std::equal_range(candidates.begin(), candidates.end(), tn, e_r_comp); +#else + void AddBestTokenCandidateTopN(const Pair n_l, + const i64 new_f, + const Vector &tks_old_first, + const std::string_view wait_append) { + const auto e_r_comp = Overload{[](const BestTokenCandidate &a, const Pair x) { return a.N_L < x; }, + [](const Pair x, const BestTokenCandidate &a) { return x < a.N_L; }}; + const auto min_comp = [](const BestTokenCandidate &a, const BestTokenCandidate &b) { return a.F < b.F; }; + const auto [it_b, it_e] = std::equal_range(candidates.begin(), candidates.end(), n_l, e_r_comp); +#endif + auto target_it = it_b; + bool do_replace = false; + if (const auto match_cnt = std::distance(it_b, it_e); match_cnt >= top_n) { + assert(match_cnt == top_n); + const auto it = std::min_element(it_b, it_e, min_comp); +#ifdef DIVIDE_F_BY_N + if (it->score_sum >= ss) { +#else + if (it->F >= new_f) { +#endif + return; + } + target_it = it; + do_replace = true; } +#ifdef DIVIDE_F_BY_N BestTokenCandidate candidate = {tn, ss}; +#else + BestTokenCandidate candidate = {n_l, new_f}; +#endif candidate.tokens.reserve(tks_old_first.size() + 1); candidate.tokens.insert(candidate.tokens.end(), tks_old_first.begin(), tks_old_first.end()); candidate.tokens.push_back(wait_append); - if (it_tn_same) { - *it = std::move(candidate); + if (do_replace) { + *target_it = std::move(candidate); } else { - candidates.insert(it, std::move(candidate)); + candidates.insert(target_it, std::move(candidate)); } } }; -constexpr i64 BASE_SCORE_SUM = 30; - -Pair, double> RAGAnalyzer::GetBestTokens(const std::string_view chars) const { +Vector, double>> RAGAnalyzer::GetBestTokensTopN(const std::string_view chars, const u32 n) const { const auto utf8_len = UTF8Length(chars); - Vector dp_vec(utf8_len + 1); + Vector dp_vec(utf8_len + 1, GrowingBestTokenCandidatesTopN(n)); dp_vec[0].candidates.resize(1); const char *current_utf8_ptr = chars.data(); u32 current_left_chars = chars.size(); @@ -876,22 +924,34 @@ Pair, double> RAGAnalyzer::GetBestTokens(const std::str lookup_until += next_one_utf8.size(); lookup_left_chars -= next_one_utf8.size(); } - auto update_dp_vec = [&dp_vec, i, j, original_sv=std::string_view{current_utf8_ptr, growing_key.size()}](const i32 key_score) { - auto &target_dp = dp_vec[i + j]; - for (const auto &[tn, ss, v] : dp_vec[i].candidates) { - target_dp.AddBestTokenCandidate(tn + 1, ss + key_score, v, original_sv); + auto dp_f = [&dp_vec, i, j, original_sv = std::string_view{current_utf8_ptr, growing_key.size()}](const i32 key_f, const u32 add_l) { +#ifdef DIVIDE_F_BY_N + const i32 key_score = key_f + add_l; + for (auto &target_dp = dp_vec[i + j]; const auto &[tn, ss, v] : dp_vec[i].candidates) { + target_dp.AddBestTokenCandidateTopN(tn + 1, ss + key_score, v, original_sv); + } +#else + auto get_add_n_l = [add_l](Pair old_n_l) { + ++old_n_l.first; + old_n_l.second += add_l; + return old_n_l; + }; + for (auto &target_dp = dp_vec[i + j]; const auto &[old_n_l, old_f, old_v] : dp_vec[i].candidates) { + target_dp.AddBestTokenCandidateTopN(get_add_n_l(old_n_l), old_f + key_f, old_v, original_sv); } +#endif }; if (const auto traverse_result = trie_->Traverse(growing_key.data(), reuse_node_pos, reuse_key_pos, growing_key.size()); traverse_result >= 0) { // in dictionary - const auto key_score = DecodeFreq(traverse_result) + static_cast(j >= 2); - update_dp_vec(key_score); + const i32 key_f = DecodeFreq(traverse_result); + const auto add_l = static_cast(j >= 2); + dp_f(key_f, add_l); } else { // not in dictionary if (j == 1) { // also give a score: -12 - update_dp_vec(-12); + dp_f(-12, 0); } if (traverse_result == -2) { // no more results @@ -904,17 +964,114 @@ Pair, double> RAGAnalyzer::GetBestTokens(const std::str current_utf8_ptr += forward_cnt; current_left_chars -= forward_cnt; } - Pair, double> result; - result.second = std::numeric_limits::lowest(); + Vector, double>> result; + result.reserve(n); + constexpr i64 B = 30; +#ifdef DIVIDE_F_BY_N for (auto &[token_num, score_sum, tokens] : dp_vec.back().candidates) { - if (const auto score = static_cast(BASE_SCORE_SUM + score_sum) / token_num; score > result.second) { - result.first = std::move(tokens); - result.second = score; + auto new_pair = std::make_pair(std::move(tokens), (static_cast(B + score_sum) / token_num)); +#else + for (auto &[N_L, F, tokens] : dp_vec.back().candidates) { + auto new_pair = std::make_pair(std::move(tokens), (F + (static_cast(B + N_L.second) / N_L.first))); +#endif + if (result.size() < n) { + result.push_back(std::move(new_pair)); + } else { + assert(result.size() == n); + if (new_pair.second > result.back().second) { + result.pop_back(); + const auto insert_pos = + std::lower_bound(result.begin(), result.end(), new_pair, [](const auto &a, const auto &b) { return a.second > b.second; }); + result.insert(insert_pos, std::move(new_pair)); + } } } return result; } +// TODO: for test +// #ifndef INFINITY_DEBUG +// #define INFINITY_DEBUG 1 +// #endif + +#ifdef INFINITY_DEBUG +namespace dp_debug { +template +String TestPrintTokens(const Vector &tokens) { + std::ostringstream oss; + for (std::size_t i = 0; i < tokens.size(); ++i) { + oss << (i ? " #" : "#") << tokens[i] << "#"; + } + return std::move(oss).str(); +} + +auto print_1 = [](const bool b) { return b ? "✅" : "❌"; }; +auto print_2 = [](const bool b) { return b ? "equal" : "not equal"; }; + +void compare_score_and_tokens(const Vector &dfs_tokens, + const double dfs_score, + const Vector &dp_tokens, + const double dp_score, + const String &prefix) { + std::ostringstream oss; + const auto b_score_eq = dp_score == dfs_score; + oss << std::format("\n{} {} DFS and DP score {}:\nDFS: {}\nDP : {}\n", print_1(b_score_eq), prefix, print_2(b_score_eq), dfs_score, dp_score); + bool vec_equal = true; + if (dp_tokens.size() != dfs_tokens.size()) { + vec_equal = false; + } else { + for (std::size_t k = 0; k < dp_tokens.size(); ++k) { + if (dp_tokens[k] != dfs_tokens[k]) { + vec_equal = false; + break; + } + } + } + oss << std::format("{} {} DFS and DP result {}:\nDFS: {}\nDP : {}\n", + print_1(vec_equal), + prefix, + print_2(vec_equal), + TestPrintTokens(dfs_tokens), + TestPrintTokens(dp_tokens)); + std::cerr << std::move(oss).str() << std::endl; +} + +inline void CheckDP(const RAGAnalyzer *this_ptr, + const std::string_view input_str, + const Vector &dfs_tokens, + const double dfs_score, + const auto t0, + const auto t1) { + const auto dp_result = this_ptr->GetBestTokensTopN(input_str, 1); + const auto t2 = std::chrono::high_resolution_clock::now(); + const auto dfs_duration = std::chrono::duration_cast>(t1 - t0); + const auto dp_duration = std::chrono::duration_cast>(t2 - t1); + const auto dp_faster = dp_duration < dfs_duration; + std::cerr << "\n!!! " << print_1(dp_faster) << "\nTOP1 DFS duration: " << dfs_duration << " \nDP duration: " << dp_duration; + const auto &[dp_vec, dp_score] = dp_result[0]; + compare_score_and_tokens(dfs_tokens, dfs_score, dp_vec, dp_score, "[1 in top1]"); +} + +inline void CheckDP2(const RAGAnalyzer *this_ptr, const std::string_view input_str, auto get_dfs_sorted_tokens, const auto t0, const auto t1) { + constexpr int topn = 2; + const auto dp_result = this_ptr->GetBestTokensTopN(input_str, topn); + const auto t2 = std::chrono::high_resolution_clock::now(); + const auto dfs_duration = std::chrono::duration_cast>(t1 - t0); + const auto dp_duration = std::chrono::duration_cast>(t2 - t1); + const auto dp_faster = dp_duration < dfs_duration; + std::cerr << "\n!!! " << print_1(dp_faster) << "\nTOP2 DFS duration: " << dfs_duration << " \nTOP2 DP duration: " << dp_duration; + const auto dfs_sorted_tokens = get_dfs_sorted_tokens(); + for (int i = 0; i < topn; ++i) { + compare_score_and_tokens(dfs_sorted_tokens[i].first, + dfs_sorted_tokens[i].second, + dp_result[i].first, + dp_result[i].second, + std::format("[{} in top{}]", i + 1, topn)); + } +} +} +#endif + String RAGAnalyzer::Merge(const String &tks_str) const { String tks = tks_str; @@ -959,62 +1116,6 @@ void RAGAnalyzer::EnglishNormalize(const Vector &tokens, Vector } } -// TODO: for test -// #ifndef INFINITY_DEBUG -// #define INFINITY_DEBUG 1 -// #endif - -#ifdef INFINITY_DEBUG -template -String TestPrintTokens(const Vector &tokens) { - std::ostringstream oss; - for (std::size_t i = 0; i < tokens.size(); ++i) { - oss << (i ? " #" : "#"); - oss << tokens[i]; - oss << "#"; - } - return std::move(oss).str(); -} - -inline void CheckDP(const RAGAnalyzer *this_ptr, - const std::string_view input_str, - const Vector &dfs_tokens, - const double dfs_score, - const auto t0, - const auto t1) { - const auto [dp_vec, dp_score] = this_ptr->GetBestTokens(input_str); - const auto t2 = std::chrono::high_resolution_clock::now(); - const auto dfs_duration = std::chrono::duration_cast>(t1 - t0); - const auto dp_duration = std::chrono::duration_cast>(t2 - t1); - auto print_1 = [](const bool b) { - return b ? "✅✅✅" : "❌❌❌"; - }; - auto print_2 = [](const bool b) { - return b ? "" : " not"; - }; - const auto dp_faster = dp_duration < dfs_duration; - std::cerr << "\n!!! " << print_1(dp_faster) << "\nDFS duration: " << dfs_duration << " \nDP duration: " << dp_duration; - const auto b_score_eq = dp_score == dfs_score; - std::cerr << std::format("\n{} DFS and DP score{} equal:\nDFS: {}\nDP : {}\n", print_1(b_score_eq), print_2(b_score_eq), dfs_score, dp_score); - bool vec_equal = true; - if (dp_vec.size() != dfs_tokens.size()) { - vec_equal = false; - } else { - for (std::size_t k = 0; k < dp_vec.size(); ++k) { - if (dp_vec[k] != dfs_tokens[k]) { - vec_equal = false; - break; - } - } - } - std::cerr << std::format("{} DFS and DP result{} equal:\nDFS: {}\nDP : {}\n", - print_1(vec_equal), - print_2(vec_equal), - TestPrintTokens(dfs_tokens), - TestPrintTokens(dp_vec)); -} -#endif - void RAGAnalyzer::TokenizeInner(Vector &res, const String &L) const { auto [tks, s] = MaxForward(L); auto [tks1, s1] = MaxBackward(L); @@ -1049,7 +1150,7 @@ void RAGAnalyzer::TokenizeInner(Vector &res, const String &L) const { Vector> pre_tokens; Vector>> token_list; Vector best_tokens; - double max_score = -100.0F; + double max_score = std::numeric_limits::lowest(); const auto str_for_dfs = Join(tks, _j, j, ""); #ifdef INFINITY_DEBUG const auto t0 = std::chrono::high_resolution_clock::now(); @@ -1057,7 +1158,7 @@ void RAGAnalyzer::TokenizeInner(Vector &res, const String &L) const { DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false); #ifdef INFINITY_DEBUG const auto t1 = std::chrono::high_resolution_clock::now(); - CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1); + dp_debug::CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1); #endif res.push_back(Join(best_tokens, 0)); @@ -1074,7 +1175,7 @@ void RAGAnalyzer::TokenizeInner(Vector &res, const String &L) const { Vector> pre_tokens; Vector>> token_list; Vector best_tokens; - double max_score = -100.0F; + double max_score = std::numeric_limits::lowest(); const auto str_for_dfs = Join(tks, _j, tks.size(), ""); #ifdef INFINITY_DEBUG const auto t0 = std::chrono::high_resolution_clock::now(); @@ -1082,7 +1183,7 @@ void RAGAnalyzer::TokenizeInner(Vector &res, const String &L) const { DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false); #ifdef INFINITY_DEBUG const auto t1 = std::chrono::high_resolution_clock::now(); - CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1); + dp_debug::CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1); #endif res.push_back(Join(best_tokens, 0)); } @@ -1121,7 +1222,7 @@ void RAGAnalyzer::TokenizeInner(Vector &res, const String &L) const { Vector> pre_tokens; Vector>> token_list; Vector best_tokens; - double max_score = -100.0F; + double max_score = std::numeric_limits::lowest(); const auto str_for_dfs = Join(tks, s, e < tks.size() ? e + 1 : e, ""); #ifdef INFINITY_DEBUG const auto t0 = std::chrono::high_resolution_clock::now(); @@ -1129,7 +1230,7 @@ void RAGAnalyzer::TokenizeInner(Vector &res, const String &L) const { DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false); #ifdef INFINITY_DEBUG const auto t1 = std::chrono::high_resolution_clock::now(); - CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1); + dp_debug::CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1); #endif // Vector, double>> sorted_tokens; // SortTokens(token_list, sorted_tokens); @@ -1287,7 +1388,19 @@ void RAGAnalyzer::FineGrainedTokenize(const String &tokens, Vector &resu Vector> pre_tokens; Vector best_tokens; double max_score = 0.0F; +#ifdef INFINITY_DEBUG + const auto t0 = std::chrono::high_resolution_clock::now(); +#endif DFS(token, 0, pre_tokens, token_list, best_tokens, max_score, true); +#ifdef INFINITY_DEBUG + const auto t1 = std::chrono::high_resolution_clock::now(); + auto get_dfs_sorted_tokens = [&]() { + Vector, double>> sorted_tokens; + SortTokens(token_list, sorted_tokens); + return sorted_tokens; + }; + dp_debug::CheckDP2(this, token, get_dfs_sorted_tokens, t0, t1); +#endif } if (token_list.size() < 2) { res.push_back(token); diff --git a/src/common/analyzer/rag_analyzer.cppm b/src/common/analyzer/rag_analyzer.cppm index 640bba5632..12168b9b28 100644 --- a/src/common/analyzer/rag_analyzer.cppm +++ b/src/common/analyzer/rag_analyzer.cppm @@ -90,7 +90,7 @@ private: void EnglishNormalize(const Vector &tokens, Vector &res); public: - Pair, double> GetBestTokens(std::string_view chars) const; + Vector, double>> GetBestTokensTopN(std::string_view chars, u32 n) const; static const SizeT term_string_buffer_limit_ = 4096 * 3;