diff --git a/youtokentome/cpp/bpe.cpp b/youtokentome/cpp/bpe.cpp index 2a0fe73..fa748f5 100644 --- a/youtokentome/cpp/bpe.cpp +++ b/youtokentome/cpp/bpe.cpp @@ -865,7 +865,7 @@ void rename_tokens(ska::flat_hash_map &char2id, } BPEState learn_bpe_from_string(string &text_utf8, int n_tokens, - const string &output_file, + StreamWriter &output, BpeConfig bpe_config) { vector threads; assert(bpe_config.n_threads >= 1 || bpe_config.n_threads == -1); @@ -1294,8 +1294,8 @@ BPEState learn_bpe_from_string(string &text_utf8, int n_tokens, rename_tokens(char2id, rules, bpe_config.special_tokens, n_tokens); BPEState bpe_state = {char2id, rules, bpe_config.special_tokens}; - bpe_state.dump(output_file); - std::cerr << "model saved to: " << output_file << std::endl; + bpe_state.dump(output); + std::cerr << "model saved to: " << output.name() << std::endl; return bpe_state; } @@ -1450,7 +1450,8 @@ void train_bpe(const string &input_path, const string &model_path, std::cerr << "reading file..." << std::endl; auto data = fast_read_file_utf8(input_path); std::cerr << "learning bpe..." << std::endl; - learn_bpe_from_string(data, vocab_size, model_path, bpe_config); + auto fout = StreamWriter.open(model_path); + learn_bpe_from_string(data, vocab_size, fout, bpe_config); } DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8, diff --git a/youtokentome/cpp/bpe.h b/youtokentome/cpp/bpe.h index dff9326..bd5ebc6 100644 --- a/youtokentome/cpp/bpe.h +++ b/youtokentome/cpp/bpe.h @@ -32,7 +32,7 @@ class BaseEncoder { explicit BaseEncoder(BPEState bpe_state, int _n_threads); - explicit BaseEncoder(const std::string& model_path, int n_threads); + explicit BaseEncoder(const StreamReader& model_path, int n_threads); void fill_from_state(); diff --git a/youtokentome/cpp/utils.cpp b/youtokentome/cpp/utils.cpp index b19f88b..a662c6c 100644 --- a/youtokentome/cpp/utils.cpp +++ b/youtokentome/cpp/utils.cpp @@ -11,6 +11,62 @@ namespace vkcom { using std::string; using std::vector; +class FileWriter : public StreamWriter { + public: + FileWriter(const std::string &file_name) { + this->file_name = file_name; + this->fout = std::ofstream(file_name, std::ios::out | std::ios::binary); + if (fout.fail()) { + std::cerr << "Can't open file: " << file_name << std::endl; + assert(false); + } + } + + virtual int write(const char *buffer, int size) override { + return fout.write(buffer, size); + } + + virtual std::string name() const noexcept override { + return file_name; + } + + private: + std::string file_name; + std::ofstream fout; +}; + +class FileReader : public StreamReader { + public: + FileReader(const std::string &file_name) { + this->file_name = file_name; + this->fin = std::ifstream(file_name, std::ios::in | std::ios::binary); + if (fin.fail()) { + std::cerr << "Can't open file: " << file_name << std::endl; + assert(false); + } + } + + virtual int read(const char *buffer, int size) override { + return fin.read(buffer, size); + } + + virtual std::string name() const noexcept override { + return file_name; + } + + private: + std::string file_name; + std::ifstream fin; +}; + +StreamWriter StreamWriter::open(const std::string &file_name) { + return FileWriter(file_name); +} + +StreamReader StreamReader::open(const std::string &file_name) { + return FileReader(file_name); +} + template::value, int>::type = 0> T bin_to_int(const char *val) { uint32_t ret = static_cast(val[0]); @@ -31,7 +87,7 @@ std::unique_ptr int_to_bin(T val) { return std::move(ret); } -void SpecialTokens::dump(std::ofstream &fout) { +void SpecialTokens::dump(StreamWriter &fout) { std::unique_ptr unk_id_ptr(int_to_bin(unk_id)), pad_id_ptr(int_to_bin(pad_id)), bos_id_ptr(int_to_bin(bos_id)), @@ -42,7 +98,7 @@ void SpecialTokens::dump(std::ofstream &fout) { fout.write(eos_id_ptr.get(), 4); } -void SpecialTokens::load(std::ifstream &fin) { +void SpecialTokens::load(StreamReader &fin) { char unk_id_bs[4], pad_id_bs[4], bos_id_bs[4], eos_id_bs[4]; fin.read(unk_id_bs, 4); fin.read(pad_id_bs, 4); @@ -85,13 +141,7 @@ bool BPE_Rule::operator==(const BPE_Rule &other) const { BPE_Rule::BPE_Rule(uint32_t x, uint32_t y, uint32_t z) : x(x), y(y), z(z) {} -void BPEState::dump(const string &file_name) { - std::ofstream fout(file_name, std::ios::out | std::ios::binary); - if (fout.fail()) { - std::cerr << "Can't open file: " << file_name << std::endl; - assert(false); - } - +void BPEState::dump(StreamWriter &fout) { std::unique_ptr char2id_ptr(int_to_bin(char2id.size())), rules_ptr(int_to_bin(rules.size())); fout.write(char2id_ptr.get(), 4); @@ -115,18 +165,11 @@ void BPEState::dump(const string &file_name) { fout.write(rule_ptr.get(), 4); } special_tokens.dump(fout); - fout.close(); } -void BPEState::load(const string &file_name) { +void BPEState::load(StreamReader &fin) { char2id.clear(); rules.clear(); - std::ifstream fin(file_name, std::ios::in | std::ios::binary); - if (fin.fail()) { - std::cerr << "Error. Can not open file with model: " << file_name - << std::endl; - exit(EXIT_FAILURE); - } char n_bs[4], m_bs[4]; fin.read(n_bs, 4); fin.read(m_bs, 4); @@ -161,7 +204,6 @@ void BPEState::load(const string &file_name) { rules.emplace_back(std::get<0>(rules_xyz[i]), std::get<1>(rules_xyz[i]), std::get<2>(rules_xyz[i])); } special_tokens.load(fin); - fin.close(); } BpeConfig::BpeConfig(double _character_coverage, int _n_threads, diff --git a/youtokentome/cpp/utils.h b/youtokentome/cpp/utils.h index d45346c..6681b5c 100644 --- a/youtokentome/cpp/utils.h +++ b/youtokentome/cpp/utils.h @@ -8,6 +8,20 @@ namespace vkcom { const uint32_t SPACE_TOKEN = 9601; +struct StreamWriter { + virtual int write(const char *buffer, int size) = 0; + virtual std::string name() const noexcept = 0; + + static StreamWriter open(const std::string &file_name); +}; + +struct StreamReader { + virtual int read(const char *buffer, int size) = 0; + virtual std::string name() const noexcept = 0; + + static StreamReader open(const std::string &file_name); +}; + struct BPE_Rule { // x + y -> z uint32_t x{0}; @@ -31,9 +45,9 @@ struct SpecialTokens { SpecialTokens(int pad_id, int unk_id, int bos_id, int eos_id); - void dump(std::ofstream &fout); + void dump(StreamWriter &fout); - void load(std::ifstream &fin); + void load(StreamReader &fin); uint32_t max_id() const; @@ -58,9 +72,9 @@ struct BPEState { std::vector rules; SpecialTokens special_tokens; - void dump(const std::string &file_name); + void dump(StreamWriter &fout); - void load(const std::string &file_name); + void load(StreamReader &fin); }; struct DecodeResult { diff --git a/youtokentome/youtokentome.py b/youtokentome/youtokentome.py index e02c84a..a9df041 100644 --- a/youtokentome/youtokentome.py +++ b/youtokentome/youtokentome.py @@ -1,22 +1,32 @@ from enum import Enum -from typing import List, Union +from functools import wraps +from typing import BinaryIO, List, Optional, Union import _youtokentome_cython + class OutputType(Enum): ID = 1 SUBWORD = 2 + class BPE: - def __init__(self, model: str, n_threads: int = -1): - self.bpe_cython = _youtokentome_cython.BPE( - model_path=model, n_threads=n_threads - ) + def __init__(self, model: Union[str, BinaryIO], n_threads: int = -1): + own_obj = isinstance(model, str) + if own_obj: + model = open(model, "rb") + try: + self.bpe_cython = _youtokentome_cython.BPE( + model_fobj=model, n_threads=n_threads + ) + finally: + if own_obj: + model.close() @staticmethod def train( data: str, - model: str, + model: Optional[Union[str, BinaryIO]], vocab_size: int, coverage: float = 1.0, n_threads: int = -1, @@ -25,17 +35,24 @@ def train( bos_id: int = 2, eos_id: int = 3, ) -> "BPE": - _youtokentome_cython.BPE.train( - data=data, - model=model, - vocab_size=vocab_size, - n_threads=n_threads, - coverage=coverage, - pad_id=pad_id, - unk_id=unk_id, - bos_id=bos_id, - eos_id=eos_id, - ) + own_obj = isinstance(model, str) + if own_obj: + model = open(model, "wb") + try: + _youtokentome_cython.BPE.train( + data=data, + model=model, + vocab_size=vocab_size, + n_threads=n_threads, + coverage=coverage, + pad_id=pad_id, + unk_id=unk_id, + bos_id=bos_id, + eos_id=eos_id, + ) + finally: + if own_obj: + model.close() return BPE(model=model, n_threads=n_threads) @@ -61,6 +78,22 @@ def encode( reverse=reverse, ) + def save(self, where: Union[str, BinaryIO]): + """ + Write the model to FS or any writeable file object. + + :param where: FS path or writeable file object. + :return: None + """ + own_obj = isinstance(where, str) + if own_obj: + where = open(where, "wb") + try: + self.bpe_cython.save(where=where) + finally: + if own_obj: + where.close() + def vocab_size(self) -> int: return self.bpe_cython.vocab_size()