diff --git a/tests/unit_tests/README.md b/tests/unit_tests/README.md new file mode 100644 index 0000000..f2ef135 --- /dev/null +++ b/tests/unit_tests/README.md @@ -0,0 +1,9 @@ +For tests execution simply run: +``` +pip install pytest +pytest +``` +Testing may take several minutes. + + + diff --git a/tests/unit_tests/stress_test.cpp b/tests/unit_tests/stress_test.cpp new file mode 100644 index 0000000..1b66b17 --- /dev/null +++ b/tests/unit_tests/stress_test.cpp @@ -0,0 +1,449 @@ +#include +#include +#include +#include +#include +#include +#include +#include "stress_test.h" + +#include "../../youtokentome/cpp/utils.h" +#include "../../youtokentome/cpp/bpe.h" +#include "../../youtokentome/cpp/utf8.h" + +#include +#include + +namespace vkcom { +using namespace std; + +extern int alive_tokens; + +using char32=uint32_t; + +BPEState learn_bpe_slow(const string &text_utf8, int n_token, string, BpeConfig bpe_config) { + auto row_data = decode_utf8(text_utf8.data(), text_utf8.data() + text_utf8.size()); + vector> splited_text; + for (auto &ch: row_data) { + if (is_space(ch)) { + ch = SPACE_TOKEN; + } + } + for (; !row_data.empty() && is_space(row_data.back()); row_data.pop_back()); + ska::flat_hash_set removed_chars; + auto char2id = compute_alphabet(row_data, removed_chars, bpe_config); + remove_rare_chars(row_data, removed_chars); + ska::flat_hash_map id2char; + for (auto x: char2id) { + id2char[x.second] = x.first; + } + int used_ids = bpe_config.special_tokens.n_special_tokens() + char2id.size(); + + for (int i = 0; i < (int) row_data.size();) { + for (; i < (int) row_data.size() && is_space(row_data[i]); i++); + if (i == (int) row_data.size()) { + break; + } + splited_text.emplace_back(); + splited_text.back().push_back(SPACE_TOKEN); + for (; i < (int) row_data.size() && !is_space(row_data[i]); i++) { + if (char2id.count(row_data[i])) { + splited_text.back().push_back(row_data[i]); + } + } + } + vector> coded; + + for (const auto &v: splited_text) { + 
coded.emplace_back(); + for (auto ch: v) { + coded.back().push_back(char2id[ch]); + } + } + + map> recipe; + map recipe_s; + for (int i = 2; i < used_ids; i++) { + recipe[i] = {i}; + recipe_s[i] = encode_utf8({id2char[i]}); + } + + auto get_recipe = [&](int x, int y) { + assert(recipe.count(x)); + assert(recipe.count(y)); + vector target_recipe; + for (auto token_id: recipe[x]) target_recipe.push_back(token_id); + for (auto token_id: recipe[y]) target_recipe.push_back(token_id); + return target_recipe; + }; + + struct Candidate { + uint32_t x, y; + int cnt; + bool operator<(const Candidate &other) const { + if (cnt != other.cnt) { + return cnt < other.cnt; + } + auto this_mn = min(x, y); + auto this_mx = max(x, y); + + auto other_mn = min(other.x, other.y); + auto other_mx = max(other.x, other.y); + + if (this_mx != other_mx) { + return this_mx > other_mx; + } + if (this_mn != other_mn) { + return this_mn > other_mn; + } + return x < other.x; + } + }; + + vector rules; + + for (; used_ids < n_token;) { + map, int> local_cnt; + + for (const auto &v: coded) { + for (int i = 0; i < (int) v.size() - 1; i++) { + local_cnt[{v[i], v[i + 1]}]++; + if (v[i] == v[i + 1] && i + 2 < (int) v.size() && v[i] == v[i + 2]) { + i++; + } + } + } + + Candidate best = {0, 0, -1}; + + for (auto cand: local_cnt) { + uint32_t x = cand.first.first; + uint32_t y = cand.first.second; + Candidate cur = {x, y, cand.second}; + if (best < cur) { + best = cur; + } + } + + if (best.cnt == -1) { + break; + } + uint32_t z = used_ids++; + rules.push_back({best.x, best.y, z}); + + recipe[z] = get_recipe(best.x, best.y); + recipe_s[z] = recipe_s[best.x] + recipe_s[best.y]; + + for (auto &v: coded) { + for (int i = 0; i < (int) v.size() - 1; i++) { + if (v[i] == static_cast(best.x) && v[i + 1] == static_cast(best.y)) { + v[i] = z; + v.erase(v.begin() + i + 1); + } + } + } + } + + BPEState state = {char2id, rules, bpe_config.special_tokens}; + return state; +} + +DecodeResult decode_slow(const string 
&text_utf8, const BaseEncoder &bpe_applyer) { + + const auto &char2id = bpe_applyer.bpe_state.char2id; + const auto &id2char = bpe_applyer.id2char; + const auto &rules = bpe_applyer.bpe_state.rules; + const auto &recipe = bpe_applyer.recipe; + + auto text = decode_utf8(text_utf8.data(), text_utf8.data() + text_utf8.size()); + for (auto &ch: text) { + if (is_space(ch)) { + ch = SPACE_TOKEN; + } + } + + for (; !text.empty() && text.back() == SPACE_TOKEN; text.pop_back()); + + struct Node { + uint32_t val; + string new_chars; + }; + + vector> words; + for (int i = 0; i < (int) text.size();) { + for (; i < (int) text.size() && is_space(text[i]); i++); + if (i == (int) text.size()) { + break; + } + + words.emplace_back(); + words.back().push_back({char2id.at(SPACE_TOKEN), {}}); + for (; i < (int) text.size() && !is_space(text[i]);) { + + if (char2id.count(text[i]) == 0) { + int cur = i; + for (; i < (int) text.size() && !is_space(text[i]) && char2id.count(text[i]) == 0; i++); + words.back().push_back({static_cast(bpe_applyer.bpe_state.special_tokens.unk_id), + encode_utf8({text.begin() + cur, text.begin() + i})}); + } else { + words.back().push_back({char2id.at(text[i]), {}}); + i++; + } + } + } + + for (auto rule: rules) { + for (auto &v: words) { + for (int i = 0; i + 1 < (int) v.size(); i++) { + if (v[i].val == rule.x && v[i + 1].val == rule.y) { + v[i].val = rule.z; + v.erase(v.begin() + i + 1); + } + } + } + } + + vector ids; + vector pieces; + for (auto &v: words) { + for (const auto &u: v) { + ids.push_back(u.val); + if (static_cast(u.val) == bpe_applyer.bpe_state.special_tokens.unk_id) { + pieces.push_back(u.new_chars); + } else { + auto recipe_u = recipe.at(u.val); + vector recipe_u_utf8; + for (auto ch: recipe_u) { + assert(id2char.count(ch)); + recipe_u_utf8.push_back(id2char.at(ch)); + } + pieces.push_back(encode_utf8(recipe_u_utf8)); + } + } + } + + return {ids, pieces}; +} + +string generate_text(int n_limit, bool flag_train) { + string sigma = flag_train 
? "abc " : "abcd "; + vector a; + int n = rand() % 1000 + 1; + n = min(n, n_limit); + string row_data; + row_data.push_back(sigma[0]); + + auto add_char = [&](char ch) { + row_data.push_back(ch); + }; + + for (; (int) row_data.size() < n;) { + if (rand() % 2) { + add_char(sigma[rand() % sigma.size()]); + } else { + int l = rand() % 5 + 2; + int seg = rand() % 4 + 1; + vector tmp; + for (int i = 0; i < seg; i++) { + tmp.push_back(sigma[rand() % sigma.size()]); /* collect the random segment; the loop below appends it l times (tmp was previously never filled, so that loop was dead) */ + } + for (int i = 0; i < l; i++) { + for (auto ch: tmp) { + add_char(ch); + } + } + } + } + if ((int) row_data.size() > n) { + row_data.resize(n); + } + for (; !row_data.empty() && is_space(row_data.back()); row_data.pop_back()); + for (; (int) row_data.size() < n;) { + row_data.push_back(sigma[0]); + } + assert(static_cast(row_data.size()) >= n); + return row_data; + +} + +void manual_test() { + string trn_data = "baba baaab"; + string inf_data = "d d"; + int n_tokens = 2 + 2 + 5; + + auto trn_data_copy = trn_data; + SpecialTokens special_tokens_config = {0, 1, 2, 3}; + BpeConfig bpe_config = {1.0, 1, special_tokens_config}; + + auto model_fast = learn_bpe_from_string(trn_data_copy, n_tokens, "remove_it.txt", bpe_config); + auto model_slow = learn_bpe_slow(trn_data, n_tokens, "remove_it.txt", bpe_config); + assert(model_fast.rules == model_slow.rules); + assert(model_fast.char2id == model_slow.char2id); + + BaseEncoder applyer(model_fast, 1); + auto ids = applyer.encode_as_ids({inf_data})[0]; + auto result_slow = decode_slow(inf_data, applyer); + assert(ids == result_slow.ids); +} + +vector to_no_space_tokens(string raw_string) { + auto tokens = decode_utf8(raw_string.data(), raw_string.data() + raw_string.size()); + int cur = 0; + for (auto ch: tokens) { + if (!is_space(ch)) { + tokens[cur++] = ch; + } + } + tokens.resize(cur); + return tokens; +} + +void parallel_test(int n_iter, int n_threads) { + for (int i = 0; i < n_iter; i++) { + srand(i); + int test_size = 1000; + auto train_data = 
generate_text(test_size, true); + int n_sentences = 1000; + vector inference_data; + for (int i = 0; i < n_sentences; i++) { + inference_data.push_back(generate_text(20, false)); + } + set unique_input_chars(train_data.begin(), train_data.end()); + int vocab_size = unique_input_chars.size() + 4 + rand() % 40; + double character_coverage = 1 - (rand() * 1.0 / RAND_MAX) * 0.4; + if (rand() % 2 == 0) { + character_coverage = 1; + } + + auto train_data_copy = train_data; + BpeConfig bpe_config = {character_coverage, n_threads, {0, 1, 2, 3}}; + auto learned_model = learn_bpe_from_string(train_data_copy, vocab_size, "remove_it.txt", bpe_config); + BaseEncoder applyer(learned_model, 20); + + vector> result_sentence_by_sentence; + for (auto s: inference_data) { + result_sentence_by_sentence.push_back(applyer.encode_as_subwords({s})[0]); + } + auto result_parallel = applyer.encode_as_subwords(inference_data); + assert(result_sentence_by_sentence == result_parallel); + } +} + +void base_stress(int n_iter) { + int n_threads = 8; + const int NUMBER_OF_SPECIAL_TOKENS_LOCAL = 4; + for (int it = 0; it != n_iter; it++) { + srand(it); + cerr << "-------------------- new test " << it << " --------------- " << endl; + int test_size = 1000; + + auto train_data = generate_text(test_size, true); + set unique_train_symbols(train_data.begin(), train_data.end()); + unique_train_symbols.insert(' '); + int vocab_size = unique_train_symbols.size() + NUMBER_OF_SPECIAL_TOKENS_LOCAL + rand() % 40; + + cerr << "train_data: !" << train_data << "! 
(vocab_size, len): (" << vocab_size << ", " << train_data.size() + << ")" << endl; + + double character_coverage = 1 - (rand() * 1.0 / RAND_MAX) * 0.4; + if (rand() % 2 == 0) { + character_coverage = 1; + } + auto train_data_copy = train_data; + BpeConfig bpe_config = {character_coverage, n_threads, {0, 1, 2, 3}}; + auto fast_solution_model = learn_bpe_from_string(train_data_copy, vocab_size, "remove_it.txt", bpe_config); + auto slow_solution_model = learn_bpe_slow(train_data, vocab_size, "remove_it.txt", bpe_config); + + if (fast_solution_model.rules != slow_solution_model.rules + || fast_solution_model.char2id != slow_solution_model.char2id) { + for (auto rr: {fast_solution_model, slow_solution_model}) { + cerr << "rules: " << endl; + cerr << "rr.rules.size(): " << rr.rules.size() << " rr.char2id.size(): " << rr.char2id.size() << endl; + for (auto rule: rr.rules) { + cerr << rule.x << " + " << rule.y << " = " << rule.z << endl; + } + for (auto x: rr.char2id) { + cerr << "id: " << x.first << " char: " << x.second << endl; + } + } + } + assert(fast_solution_model.rules == slow_solution_model.rules); + assert(fast_solution_model.char2id == slow_solution_model.char2id); + + BaseEncoder applyer(fast_solution_model, 1); + + auto inference_data = generate_text(test_size, false); + cerr << "inference_data: " << inference_data << endl; + auto fast_ids = applyer.encode_as_ids({inference_data})[0]; + auto fast_pieces = applyer.encode_as_subwords({inference_data})[0]; + auto slow_results = decode_slow(inference_data, applyer); + vector slow_pieces; + for (auto x: slow_results.pieces) { + slow_pieces.push_back(x); + } + + if (fast_ids != slow_results.ids) { + cerr << "ids real: "; + for (auto x: fast_ids) cerr << x << " "; + cerr << endl; + cerr << "ids slow: "; + for (auto x: slow_results.ids) cerr << x << " "; + cerr << endl; + cerr << "pieces real: "; + for (auto x: fast_pieces) cerr << x << " "; + cerr << endl; + cerr << "pieces slow: "; + for (auto x: 
slow_results.pieces) cerr << x << " "; + cerr << endl; + } + assert(fast_ids == slow_results.ids); + assert(fast_pieces == slow_pieces); + + string fast_result_one_line; + for (const auto &x: fast_pieces) fast_result_one_line += x; + string slow_result_one_line = ""; + for (const auto &x: slow_pieces) slow_result_one_line += x; + + auto original_no_space = to_no_space_tokens(inference_data); + auto fast_no_space = to_no_space_tokens(fast_result_one_line); + auto slow_no_space = to_no_space_tokens(slow_result_one_line); + + if (fast_no_space != original_no_space) { + cerr << "original_no_space: "; + for (auto x: original_no_space) { cerr << x << " "; } + cerr << endl; + cerr << "fast_no_space: "; + for (auto x: fast_no_space) { cerr << x << " "; } + cerr << endl; + cerr << "slow_no_space: "; + for (auto x: slow_no_space) { cerr << x << " "; } + cerr << endl; + } + assert(fast_no_space == original_no_space); + } +} +} + +int main(int argc, char **argv) { + if (argc == 1) { + vkcom::base_stress(-1); + } else { + int n_iter = 0; + if (std::string(argv[1]) == "manual") { + vkcom::manual_test(); + return 0; + } + if (std::string(argv[1]) == "parallel") { + if (argc < 3 || sscanf(argv[2], "%d", &n_iter) != 1) return 1; /* iteration count missing or not a number: fail instead of reading argv[2] == NULL */ + vkcom::parallel_test(n_iter, 8); + return 0; + } + if (std::string(argv[1]) == "base") { + if (argc < 3 || sscanf(argv[2], "%d", &n_iter) != 1) return 1; /* iteration count missing or not a number: fail instead of reading argv[2] == NULL */ + vkcom::base_stress(n_iter); + return 0; + } + assert(false); + } +} + diff --git a/tests/unit_tests/stress_test.h b/tests/unit_tests/stress_test.h new file mode 100644 index 0000000..0422614 --- /dev/null +++ b/tests/unit_tests/stress_test.h @@ -0,0 +1,21 @@ +#pragma once + + +#include "../../youtokentome/cpp/third_party/flat_hash_map.h" +#include "../../youtokentome/cpp/utils.h" + + +namespace vkcom { +ska::flat_hash_map +compute_alphabet(const std::vector &data, ska::flat_hash_set &removed_chars, const BpeConfig &bpe_config); + +void remove_rare_chars(std::vector &data, const ska::flat_hash_set &removed_chars); + +BPEState learn_bpe_from_string(std::string 
&text_utf8, int n_tokens, const std::string &output_file, BpeConfig bpe_config); + +void utf8_to_chars(uint32_t x, std::back_insert_iterator it); + +uint32_t chars_to_utf8(const char *begin, size_t size, size_t *utf8_len); + + +} diff --git a/tests/unit_tests/test_stress.py b/tests/unit_tests/test_stress.py new file mode 100644 index 0000000..942fd97 --- /dev/null +++ b/tests/unit_tests/test_stress.py @@ -0,0 +1,44 @@ +import os +from subprocess import run + + +def compile_test(): + build_files = ["bpe.cpp", "utils.cpp", "utf8.cpp"] + files = ["../../youtokentome/cpp/" + file_name for file_name in build_files] + files.append("stress_test.cpp") + + print("compiling stress test ...") + + run( + [ + "g++", + *files, + "-o", + "test", + "-std=c++14", + "-pthread", + "-D_GLIBCXX_DEBUG", + "-DDETERMINISTIC_QUEUE", + ], + check=True, + ) + + +def test_stress(): + compile_test() + run(["./test", "base", "1000"], check=True) + os.remove("test") + + +def test_manual(): + compile_test() + run(["./test", "manual"], check=True) + os.remove("test") + os.remove("remove_it.txt") + + +def test_parallel(): + compile_test() + run(["./test", "parallel", "50"], check=True) + os.remove("test") + os.remove("remove_it.txt") diff --git a/youtokentome/cpp/bpe.cpp b/youtokentome/cpp/bpe.cpp index f808ea7..2a0fe73 100644 --- a/youtokentome/cpp/bpe.cpp +++ b/youtokentome/cpp/bpe.cpp @@ -56,7 +56,7 @@ struct VectorSegment { } // namespace vkcom namespace std { -template <> +template<> struct hash { size_t operator()(const vkcom::VectorSegment &x) const { return x.hash; } }; @@ -75,7 +75,7 @@ string fast_read_file_utf8(const string &file_name) { while (true) { size_t cur_size = res.size(); res.resize(cur_size + buf_size); - int buf_len = fread((void *)(res.data() + cur_size), 1, buf_size, fin); + int buf_len = fread((void *) (res.data() + cur_size), 1, buf_size, fin); if (buf_len < buf_size) { res.resize(res.size() - (buf_size - buf_len)); fclose(fin); @@ -99,7 +99,7 @@ bool is_space(uint32_t 
ch) { } uint64_t int2comb(uint32_t a, uint32_t b) { - return (static_cast(a) << 32u) + b; + return (static_cast(a) << 32u) + b; } struct MergeCandidate { @@ -109,8 +109,8 @@ struct MergeCandidate { MergeCandidate() = default; - MergeCandidate(size_t count, uint32_t left_token, uint32_t right_token) - : count(count), left_token(left_token), right_token(right_token) {} + MergeCandidate(size_t count, uint32_t left_token, uint32_t right_token) : count(count), left_token(left_token), + right_token(right_token) {} bool operator<(const MergeCandidate &other) const { if (count != other.count) { @@ -138,7 +138,7 @@ struct Position { bool operator<(const Position &other) const { return word_id < other.word_id || - (word_id == other.word_id && pos_id < other.pos_id); + (word_id == other.word_id && pos_id < other.pos_id); } }; @@ -151,12 +151,12 @@ struct PositionsCnt { vector positions; size_t cnt; }; - bool rule_intersection(BPE_Rule rule, uint32_t new_left, uint32_t new_right) { return rule.y == new_left || rule.x == new_right; } struct SmallObjectQueue { + vector> queue; bool flag_started{false}; size_t _size{0}; @@ -172,14 +172,29 @@ struct SmallObjectQueue { }; queue[event.count].push_back(event); _size++; +#ifdef DETERMINISTIC_QUEUE + if (queue.size() - 1 == event.count && flag_started) { + sort(queue.back().begin(), queue.back().end()); + } +#endif } void process_empty_slots() { +#ifdef DETERMINISTIC_QUEUE + bool moved_down = !flag_started; +#endif flag_started = true; - while (!queue.empty() && queue.back().empty()) { - queue.pop_back(); + for (; !queue.empty() && queue.back().empty(); queue.pop_back()) { +#ifdef DETERMINISTIC_QUEUE + moved_down = true; +#endif } +#ifdef DETERMINISTIC_QUEUE + if (moved_down && !queue.empty()) { + sort(queue.back().begin(), queue.back().end()); + } +#endif } bool empty() { @@ -200,7 +215,9 @@ struct SmallObjectQueue { _size--; } - size_t size() const { return _size; } + size_t size() const { + return _size; + } }; struct BigObjectQueue 
{ @@ -209,17 +226,19 @@ struct BigObjectQueue { BigObjectQueue(size_t big_event_bound) : big_event_bound(big_event_bound) {} - void push(const MergeCandidate &event) { big_events.push_back(event); } + void push(const MergeCandidate &event) { + big_events.push_back(event); + } - bool empty() const { return big_events.empty(); } + bool empty() const { + return big_events.empty(); + } - bool top(std::function &check_cnt, MergeCandidate &ret, - SmallObjectQueue *small_object_queue, BPE_Rule last_rule) { + bool top(std::function &check_cnt, MergeCandidate &ret, SmallObjectQueue *small_object_queue, + BPE_Rule last_rule) { for (size_t i = 0; i < big_events.size();) { - if (!rule_intersection(last_rule, big_events[i].left_token, - big_events[i].right_token)) { - uint64_t comb = - int2comb(big_events[i].left_token, big_events[i].right_token); + if (!rule_intersection(last_rule, big_events[i].left_token, big_events[i].right_token)) { + uint64_t comb = int2comb(big_events[i].left_token, big_events[i].right_token); assert(big_events[i].count >= check_cnt(comb)); big_events[i].count = check_cnt(comb); } @@ -232,11 +251,15 @@ struct BigObjectQueue { i++; } } +#ifdef DETERMINISTIC_QUEUE + sort(big_events.begin(), big_events.end()); /// TODO remove unoptimal code +#else for (auto &big_event : big_events) { if (big_event.count > big_events.back().count) { std::swap(big_event, big_events.back()); } } +#endif if (big_events.empty()) { return false; @@ -250,7 +273,9 @@ struct BigObjectQueue { big_events.pop_back(); } - size_t size() const { return big_events.size(); } + size_t size() const { + return big_events.size(); + } }; struct PriorityQueue { @@ -258,9 +283,8 @@ struct PriorityQueue { BigObjectQueue big_queue; size_t big_event_bound; - explicit PriorityQueue(size_t dataset_size) - : big_queue(static_cast(sqrt(dataset_size))), - big_event_bound(static_cast(sqrt(dataset_size))) {} + explicit PriorityQueue(size_t dataset_size) : big_queue(static_cast(sqrt(dataset_size))), + 
big_event_bound(static_cast(sqrt(dataset_size))) {} void push(const MergeCandidate &event) { if (event.count == 0) { @@ -277,8 +301,7 @@ struct PriorityQueue { return big_queue.empty() && small_queue.empty(); } - MergeCandidate top(std::function &check_cnt, - BPE_Rule last_rule) { + MergeCandidate top(std::function &check_cnt, BPE_Rule last_rule) { MergeCandidate res; bool has_top = big_queue.top(check_cnt, res, &small_queue, last_rule); if (has_top) { @@ -295,7 +318,9 @@ struct PriorityQueue { } } - size_t size() const { return big_queue.size() + small_queue.size(); } + size_t size() const { + return big_queue.size() + small_queue.size(); + } }; ska::flat_hash_map compute_alphabet_helper( @@ -311,9 +336,9 @@ ska::flat_hash_map compute_alphabet_helper( size_t cur = 0; size_t n_removed = 0; for (; cur < frequencies.size() && - (data_len - n_removed - frequencies[cur].first) > - data_len * bpe_config.character_coverage; - cur++) { + (data_len - n_removed - frequencies[cur].first) > + data_len * bpe_config.character_coverage; + cur++) { n_removed += frequencies[cur].first; } std::cerr << "number of unique characters in the training data: " @@ -461,9 +486,9 @@ void time_check(const string &message) { if (!message.empty()) { std::cerr << "## time " << message << " ... 
" << std::chrono::duration_cast( - cur_moment - last_time_stamp) - .count() * - 1.0 / 1e6 + cur_moment - last_time_stamp) + .count() * + 1.0 / 1e6 << std::endl; } last_time_stamp = cur_moment; @@ -472,9 +497,9 @@ void time_check(const string &message) { double time_check_silent() { auto cur_moment = std::chrono::steady_clock::now(); double ret = std::chrono::duration_cast( - cur_moment - last_time_stamp) - .count() * - 1.0 / 1e6; + cur_moment - last_time_stamp) + .count() * + 1.0 / 1e6; last_time_stamp = cur_moment; return ret; } @@ -579,7 +604,7 @@ void worker_doing_merge( auto self_full_remove = [&](size_t word_id, size_t pos_id) { uint64_t comb = get_self_code(word_id, pos_id); uint32_t real_cnt = word_freq[word_id] * - pairsInSeg(lists_of_tokens[word_id][pos_id].seg_len); + pairsInSeg(lists_of_tokens[word_id][pos_id].seg_len); pair2cnt[comb] -= real_cnt; }; @@ -608,7 +633,7 @@ void worker_doing_merge( std::unique_lock ul(mt[thread_id]); cv[thread_id].wait(ul, [&] { return task_order[cur_token_rule % 2].z == cur_token_rule || - cur_token_rule >= real_n_tokens; + cur_token_rule >= real_n_tokens; }); assert(cur_token_rule <= real_n_tokens); if (cur_token_rule == real_n_tokens) { @@ -851,7 +876,7 @@ BPEState learn_bpe_from_string(string &text_utf8, int n_tokens, for (size_t i = 1; i <= n_threads; i++) { size_t candidate = text_utf8.size() * i / n_threads; for (; candidate < text_utf8.size() && !is_space(text_utf8[candidate]); - candidate++) { + candidate++) { } split_pos.push_back(candidate); @@ -1034,7 +1059,7 @@ BPEState learn_bpe_from_string(string &text_utf8, int n_tokens, size_t used_ids = char2id.size() + bpe_config.special_tokens.n_special_tokens(); - if (used_ids > (size_t)n_tokens) { + if (used_ids > (size_t) n_tokens) { std::cerr << "Incorrect arguments. Vocabulary size too small. Set vocab_size>=" << used_ids << ". 
Current value for vocab_size=" << n_tokens @@ -1100,12 +1125,12 @@ BPEState learn_bpe_from_string(string &text_utf8, int n_tokens, int inter_fail = 0; int equal_fail = 0; vector> progress_debug; - while (used_ids < (size_t)n_tokens) { + while (used_ids < (size_t) n_tokens) { uint32_t x, y, z; assert(finished_cur <= used_ids && used_ids <= finished_cur + 2); bool progress = false; - if (used_ids < (size_t)n_tokens && used_ids - finished_cur < 2 && + if (used_ids < (size_t) n_tokens && used_ids - finished_cur < 2 && last_failed_try < finished_cur) { progress = true; for (size_t i = 0; i < n_threads; i++) { @@ -1132,18 +1157,18 @@ BPEState learn_bpe_from_string(string &text_utf8, int n_tokens, } } BPE_Rule last_rule = (used_ids - finished_cur == 1) - ? rules.back() - : BPE_Rule({0, 0, 0}); + ? rules.back() + : BPE_Rule({0, 0, 0}); auto merge_event = merge_order.top(check_cnt, last_rule); if ((used_ids - finished_cur == 1) && (merge_event.left_token == rules.back().y || - merge_event.right_token == rules.back().x || - (!rules.empty() && rules.back().x == rules.back().y))) { + merge_event.right_token == rules.back().x || + (!rules.empty() && rules.back().x == rules.back().y))) { inter_fail += merge_event.left_token == rules.back().y || - merge_event.right_token == rules.back().x; + merge_event.right_token == rules.back().x; equal_fail += !rules.empty() && rules.back().x == rules.back().y && - used_ids - finished_cur == 1; + used_ids - finished_cur == 1; last_failed_try = finished_cur; x = y = z = 0; @@ -1342,7 +1367,7 @@ void check_config(BpeConfig &bpe_config, int vocab_size) { bpe_config.special_tokens.unk_id >= vocab_size) { std::cerr << "Invalid value. unk_id: must be in the range [0, vocab_size - " "1]. Current value of vocab_size = " + - std::to_string(vocab_size) + "." + std::to_string(vocab_size) + "." 
<< std::endl; exit(EXIT_FAILURE); } @@ -1351,7 +1376,7 @@ void check_config(BpeConfig &bpe_config, int vocab_size) { bpe_config.special_tokens.pad_id >= vocab_size) { std::cerr << "Invalid value. pad_id must be in the range [-1, vocab_size - " "1]. Current value of vocab_size = " + - std::to_string(vocab_size) + "." + std::to_string(vocab_size) + "." << std::endl; exit(EXIT_FAILURE); } @@ -1360,7 +1385,7 @@ void check_config(BpeConfig &bpe_config, int vocab_size) { bpe_config.special_tokens.bos_id >= vocab_size) { std::cerr << "Invalid value. bos_id must be in the range [-1, vocab_size - " "1]. Current value of vocab_size = " + - std::to_string(vocab_size) + "." + std::to_string(vocab_size) + "." << std::endl; exit(EXIT_FAILURE); } @@ -1369,7 +1394,7 @@ void check_config(BpeConfig &bpe_config, int vocab_size) { bpe_config.special_tokens.eos_id >= vocab_size) { std::cerr << "Invalid value. eos_id must be in the range [-1, vocab_size - " "1]. Current value of vocab_size = " + - std::to_string(vocab_size) + "." + std::to_string(vocab_size) + "." 
<< std::endl; exit(EXIT_FAILURE); } @@ -1450,7 +1475,7 @@ DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8, bool operator<(const MergeEvent2 &other) const { return priority > other.priority || - (priority == other.priority && pos > other.pos); + (priority == other.priority && pos > other.pos); } }; @@ -1595,7 +1620,6 @@ DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8, std::reverse(output_pieces.begin(), output_pieces.end()); } } - return {output_ids, output_pieces}; } @@ -1618,7 +1642,7 @@ BaseEncoder::BaseEncoder(const string &model_path, int _n_threads) } } -template +template vector concat_vectors(const vector &a, const vector &b) { vector c; c.reserve(a.size() + b.size()); @@ -1632,7 +1656,7 @@ void BaseEncoder::fill_from_state() { id2char[x.second] = x.first; } - for (int i = 0; i < (int)bpe_state.rules.size(); i++) { + for (int i = 0; i < (int) bpe_state.rules.size(); i++) { rule2id[int2comb(bpe_state.rules[i].x, bpe_state.rules[i].y)] = i; } @@ -1654,7 +1678,7 @@ void BaseEncoder::fill_from_state() { int BaseEncoder::vocab_size() const { return bpe_state.rules.size() + bpe_state.char2id.size() + - bpe_state.special_tokens.n_special_tokens(); + bpe_state.special_tokens.n_special_tokens(); } std::vector BaseEncoder::encode_parallel( @@ -1730,8 +1754,8 @@ string BaseEncoder::id_to_subword(int id, bool replace_space) const { if (id < 0 || vocab_size() <= id) { std::cerr << "Error: Invalid value for id. id must be in the range [0, " "vocab_size - 1]. 
Current value: vocab_size = " + - std::to_string(vocab_size()) + - "; id=" + std::to_string(id) + ";" + std::to_string(vocab_size()) + + "; id=" + std::to_string(id) + ";" << std::endl; exit(EXIT_FAILURE); } diff --git a/youtokentome/cpp/utf8.h b/youtokentome/cpp/utf8.h index 2ee4382..ec66ca9 100644 --- a/youtokentome/cpp/utf8.h +++ b/youtokentome/cpp/utf8.h @@ -3,11 +3,15 @@ #include "utils.h" namespace vkcom { + std::string encode_utf8(const std::vector &utext); std::vector decode_utf8(const char *begin, const char *end); -std::vector decode_utf8(const std::string& utf8_text); +std::vector decode_utf8(const std::string &utf8_text); + + + } // namespace vkcom diff --git a/youtokentome/cpp/utils.cpp b/youtokentome/cpp/utils.cpp index 2c77619..901e6ef 100644 --- a/youtokentome/cpp/utils.cpp +++ b/youtokentome/cpp/utils.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include namespace vkcom { diff --git a/youtokentome/cpp/utils.h b/youtokentome/cpp/utils.h index d65b4d0..d45346c 100644 --- a/youtokentome/cpp/utils.h +++ b/youtokentome/cpp/utils.h @@ -18,7 +18,7 @@ struct BPE_Rule { BPE_Rule(uint32_t x, uint32_t y, uint32_t z); - bool operator==(const BPE_Rule& other) const; + bool operator==(const BPE_Rule &other) const; }; struct SpecialTokens { @@ -31,9 +31,9 @@ struct SpecialTokens { SpecialTokens(int pad_id, int unk_id, int bos_id, int eos_id); - void dump(std::ofstream& fout); + void dump(std::ofstream &fout); - void load(std::ifstream& fin); + void load(std::ifstream &fin); uint32_t max_id() const; @@ -50,7 +50,7 @@ struct BpeConfig { BpeConfig() = default; BpeConfig(double character_coverage, int n_threads, - const SpecialTokens& special_tokens); + const SpecialTokens &special_tokens); }; struct BPEState { @@ -58,9 +58,9 @@ struct BPEState { std::vector rules; SpecialTokens special_tokens; - void dump(const std::string& file_name); + void dump(const std::string &file_name); - void load(const std::string& file_name); + void load(const std::string 
&file_name); }; struct DecodeResult { @@ -77,12 +77,12 @@ struct EncodingConfig { bool is_space(uint32_t ch); std::vector read_lines_from_stdin(size_t batch_limit, - size_t* processed); + size_t *processed); -template -void write_to_stdout(const std::vector>& sentences, bool flush) { - for (const auto& sentence : sentences) { - for (const auto& token : sentence) { +template +void write_to_stdout(const std::vector> &sentences, bool flush) { + for (const auto &sentence : sentences) { + for (const auto &token : sentence) { std::cout << token << " "; } std::cout << "\n"; diff --git a/youtokentome/cpp/yttm.cpp b/youtokentome/cpp/yttm.cpp index 116ee9f..720ae07 100644 --- a/youtokentome/cpp/yttm.cpp +++ b/youtokentome/cpp/yttm.cpp @@ -1,4 +1,4 @@ -/* Generated by Cython 0.29.12 */ +/* Generated by Cython 0.29.13 */ /* BEGIN: Cython Metadata { @@ -37,8 +37,8 @@ END: Cython Metadata */ #elif PY_VERSION_HEX < 0x02060000 || (0x03000000 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x03030000) #error Cython requires Python 2.6+ or Python 3.3+. #else -#define CYTHON_ABI "0_29_12" -#define CYTHON_HEX_VERSION 0x001D0CF0 +#define CYTHON_ABI "0_29_13" +#define CYTHON_HEX_VERSION 0x001D0DF0 #define CYTHON_FUTURE_DIVISION 0 #include #ifndef offsetof @@ -3290,7 +3290,6 @@ static PyObject *__pyx_pf_20_youtokentome_cython_3BPE_22vocab_cli(struct __pyx_o * def vocab_cli(self, verbose): * self.encoder.vocab_cli(verbose) # <<<<<<<<<<<<<< * - * */ __pyx_t_1 = __Pyx_PyObject_IsTrue(__pyx_v_verbose); if (unlikely((__pyx_t_1 == ((bool)-1)) && PyErr_Occurred())) __PYX_ERR(0, 128, __pyx_L1_error) __pyx_v_self->encoder->vocab_cli(__pyx_t_1);