From a1555a024cce7c1f24a8693b09be9497b7241051 Mon Sep 17 00:00:00 2001
From: Vadim Markovtsev
Date: Thu, 10 Oct 2019 13:30:23 +0200
Subject: [PATCH] Change the model format to binary

We are using little endian numbers under the hood.
Rule's x, y and z are written by plane, not interleaved.

Signed-off-by: Vadim Markovtsev
---
 youtokentome/cpp/utils.cpp | 104 ++++++++++++++++++++++++++++++-------
 1 file changed, 85 insertions(+), 19 deletions(-)

diff --git a/youtokentome/cpp/utils.cpp b/youtokentome/cpp/utils.cpp
index 901e6ef..0d8ea06 100644
--- a/youtokentome/cpp/utils.cpp
+++ b/youtokentome/cpp/utils.cpp
@@ -2,20 +2,58 @@
 #include <cassert>
 #include <fstream>
 #include <iostream>
+#include <memory>
 #include <string>
+#include <type_traits>
 #include <vector>
 
 namespace vkcom {
 using std::string;
 using std::vector;
 
+// Decodes the 4 bytes at `val` as a little-endian 32-bit integer.
+template<typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
+T bin_to_int(const char *val) {
+  uint32_t ret = static_cast<uint8_t>(val[0]);
+  ret |= static_cast<uint32_t>(static_cast<uint8_t>(val[1])) << 8;
+  ret |= static_cast<uint32_t>(static_cast<uint8_t>(val[2])) << 16;
+  ret |= static_cast<uint32_t>(static_cast<uint8_t>(val[3])) << 24;
+  return ret;
+}
+
+// Encodes the low 32 bits of `val` as 4 little-endian bytes.
+template<typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
+std::unique_ptr<char[]> int_to_bin(T val) {
+  auto u32 = static_cast<uint32_t>(val);
+  std::unique_ptr<char[]> ret(new char[4]);
+  ret[0] = u32 & 0xFF;
+  ret[1] = (u32 >> 8) & 0xFF;
+  ret[2] = (u32 >> 16) & 0xFF;
+  ret[3] = (u32 >> 24);  // no need for & 0xFF
+  return ret;  // plain return: std::move here would block copy elision
+}
+
 void SpecialTokens::dump(std::ofstream &fout) {
-  fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id
-       << std::endl;
+  std::unique_ptr<char[]> unk_id_ptr(int_to_bin(unk_id)),
+      pad_id_ptr(int_to_bin(pad_id)),
+      bos_id_ptr(int_to_bin(bos_id)),
+      eos_id_ptr(int_to_bin(eos_id));
+  fout.write(unk_id_ptr.get(), 4);
+  fout.write(pad_id_ptr.get(), 4);
+  fout.write(bos_id_ptr.get(), 4);
+  fout.write(eos_id_ptr.get(), 4);
 }
 
 void SpecialTokens::load(std::ifstream &fin) {
-  fin >> unk_id >> pad_id >> bos_id >> eos_id;
+  char unk_id_bs[4], pad_id_bs[4], bos_id_bs[4], eos_id_bs[4];
+  fin.read(unk_id_bs, 4);
+  fin.read(pad_id_bs, 4);
+  fin.read(bos_id_bs, 4);
+  fin.read(eos_id_bs, 4);
+  this->unk_id = bin_to_int<uint32_t>(unk_id_bs);
+  this->pad_id = bin_to_int<uint32_t>(pad_id_bs);
+  this->bos_id = bin_to_int<uint32_t>(bos_id_bs);
+  this->eos_id = bin_to_int<uint32_t>(eos_id_bs);
 }
 
 uint32_t SpecialTokens::max_id() const {
@@ -50,18 +88,34 @@ bool BPE_Rule::operator==(const BPE_Rule &other) const {
 BPE_Rule::BPE_Rule(uint32_t x, uint32_t y, uint32_t z) : x(x), y(y), z(z) {}
 
 void BPEState::dump(const string &file_name) {
-  std::ofstream fout(file_name, std::ios::out);
+  std::ofstream fout(file_name, std::ios::out | std::ios::binary);
   if (fout.fail()) {
     std::cerr << "Can't open file: " << file_name << std::endl;
     assert(false);
   }
 
-  fout << char2id.size() << " " << rules.size() << std::endl;
-  for (auto s : char2id) {
-    fout << s.first << " " << s.second << std::endl;
-  }
-  for (auto rule : rules) {
-    fout << rule.x << " " << rule.y << " " << rule.z << std::endl;
+  std::unique_ptr<char[]> char2id_ptr(int_to_bin(char2id.size())),
+      rules_ptr(int_to_bin(rules.size()));
+  fout.write(char2id_ptr.get(), 4);
+  fout.write(rules_ptr.get(), 4);
+  for (auto &s : char2id) {
+    std::unique_ptr<char[]> first_ptr(int_to_bin(s.first)),
+        second_ptr(int_to_bin(s.second));
+    fout.write(first_ptr.get(), 4);
+    fout.write(second_ptr.get(), 4);
+  }
+  // Rules are written by plane: every x, then every y, then every z.
+  for (auto &rule : rules) {
+    std::unique_ptr<char[]> rule_ptr(int_to_bin(rule.x));
+    fout.write(rule_ptr.get(), 4);
+  }
+  for (auto &rule : rules) {
+    std::unique_ptr<char[]> rule_ptr(int_to_bin(rule.y));
+    fout.write(rule_ptr.get(), 4);
+  }
+  for (auto &rule : rules) {
+    std::unique_ptr<char[]> rule_ptr(int_to_bin(rule.z));
+    fout.write(rule_ptr.get(), 4);
   }
   special_tokens.dump(fout);
   fout.close();
@@ -70,24 +124,36 @@ void BPEState::dump(const string &file_name) {
 void BPEState::load(const string &file_name) {
   char2id.clear();
   rules.clear();
-  std::ifstream fin(file_name, std::ios::in);
+  std::ifstream fin(file_name, std::ios::in | std::ios::binary);
   if (fin.fail()) {
     std::cerr << "Error. Can not open file with model: " << file_name
               << std::endl;
     exit(EXIT_FAILURE);
   }
-  int n, m;
-  fin >> n >> m;
+  char n_bs[4], m_bs[4];
+  fin.read(n_bs, 4);
+  fin.read(m_bs, 4);
+  auto n = bin_to_int<int>(n_bs);
+  auto m = bin_to_int<int>(m_bs);
   for (int i = 0; i < n; i++) {
-    uint32_t inner_id;
-    uint32_t utf32_id;
-    fin >> inner_id >> utf32_id;
+    char inner_id_bs[4], utf32_id_bs[4];
+    fin.read(inner_id_bs, 4);
+    fin.read(utf32_id_bs, 4);
+    auto inner_id = bin_to_int<uint32_t>(inner_id_bs);
+    auto utf32_id = bin_to_int<uint32_t>(utf32_id_bs);
     char2id[inner_id] = utf32_id;
   }
+  // Rules are stored by plane: every x, then every y, then every z.
+  vector<uint32_t> planes[3];
+  for (auto &plane : planes) {
+    for (int i = 0; i < m; i++) {
+      char val[4];
+      fin.read(val, 4);
+      plane.push_back(bin_to_int<uint32_t>(val));
+    }
+  }
   for (int i = 0; i < m; i++) {
-    uint32_t x, y, z;
-    fin >> x >> y >> z;
-    rules.emplace_back(x, y, z);
+    rules.emplace_back(planes[0][i], planes[1][i], planes[2][i]);
   }
   special_tokens.load(fin);
   fin.close();