diff --git a/README.md b/README.md index 72b268c..63ceded 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ Class constructor. Loads the trained model. Class `youtokentome.BPE` has the following methods: #### encode ```python -encode(self, sentences, output_type=yttm.OutputType.ID, bos=False, eos=False, reverse=False) +encode(self, sentences, output_type=yttm.OutputType.ID, bos=False, eos=False, reverse=False, dropout_prob=0) ``` **Args:** @@ -117,6 +117,7 @@ encode(self, sentences, output_type=yttm.OutputType.ID, bos=False, eos=False, re * `bos`: bool, if True then token “beginning of sentence” will be added * `eos`: bool, if True then token “end of sentence” will be added * `reverse`: bool, if True the output sequence of tokens will be reversed +* `dropout_prob`: float, BPE-dropout probability (the probability of a merge being dropped). Must be in the range [0, 1]. **Returns:** If `output_type` is equal to `youtokentome.OutputType.ID` or `youtokentome.OutputType.SUBWORD` @@ -258,6 +259,7 @@ Options: --eos Add tab 'end of sentence'. --reverse Reverse output sequence of tokens. --stream Process each line before reading the next one. + --dropout_prob BPE-dropout probability (the probability of a merge being dropped). [default: 0] --help Show this message and exit. ``` diff --git a/setup.py b/setup.py index 5633b70..a83cff6 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ setup( name="youtokentome", - version="1.0.3", + version="1.0.4", packages=find_packages(), description="Unsupervised text tokenizer focused on computational efficiency", long_description=LONG_DESCRIPTION, @@ -49,4 +49,5 @@ "Programming Language :: C++", ], ext_modules=cythonize(extensions), -) \ No newline at end of file +) + diff --git a/youtokentome/cpp/bpe.cpp b/youtokentome/cpp/bpe.cpp index 0c0a512..61e28df 100644 --- a/youtokentome/cpp/bpe.cpp +++ b/youtokentome/cpp/bpe.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include "third_party/flat_hash_map.h" @@ -1459,6 +1460,71 @@ Status train_bpe(const string &input_path, const string &model_path, return Status(); } + +template +class BasePriorityQueue { + public: + virtual void push(T x) = 0; + virtual bool pop(T& x) = 0; + virtual ~BasePriorityQueue() {} +}; + +template +class STLQueue : public BasePriorityQueue { + std::priority_queue q; + void push(T x) override { + q.push(x); + } + bool pop(T& x) override { + if (q.empty()) { + return false; + } + x = q.top(); + q.pop(); + return true; + } +}; + +std::mt19937 rnd; + +template +class DropoutQueue : public BasePriorityQueue { + double skip_prob; + std::uniform_real_distribution<> dist; + std::priority_queue q; + vector skipped_elements; + public: + explicit DropoutQueue(double _skip_prob):skip_prob(_skip_prob), dist(std::uniform_real_distribution<>(0, 1)) {} + void push(T x) override { + q.push(x); + } + bool pop(T& x) override { + assert(skipped_elements.empty()); + while (true) { + if (q.empty()) { + for (auto y: skipped_elements) { + q.push(y); + } + skipped_elements.clear(); + return false; + } + T temp = q.top(); + q.pop(); + if (dist(rnd) < skip_prob) { + skipped_elements.push_back(temp); + } + else { + for (auto y: skipped_elements) { + q.push(y); + } + skipped_elements.clear(); + x = temp; + return true; + } + } + } +}; + DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8, const EncodingConfig &encoding_config, OutputType output_type) const { @@ -1539,17 +1605,24 @@ DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8, } list.back().next = -1; - std::priority_queue queue; auto pair_code = [&](uint64_t first_pos) { auto second_pos = list[first_pos].next; return int2comb(list[first_pos].token_id, list[second_pos].token_id); }; + std::unique_ptr> queue(nullptr); + if (encoding_config.dropout_prob == 0) { + queue.reset(new STLQueue()); + } + else { + queue.reset(new DropoutQueue(encoding_config.dropout_prob)); + } + auto push_in_queue_if_rule_exist = [&](uint64_t pos) { auto it = rule2id.find(pair_code(pos)); if (it != rule2id.end()) { - queue.push({it->second, static_cast(pos)}); + queue->push({it->second, static_cast(pos)}); } }; @@ -1557,9 +1630,11 @@ DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8, push_in_queue_if_rule_exist(j); } - while (!queue.empty()) { - MergeEvent2 event = queue.top(); - queue.pop(); + while (true) { + MergeEvent2 event; + if (!queue->pop(event)) { + break; + } int rule_id = event.priority; int pos_1 = event.pos; int pos_2 = list[pos_1].next; @@ -1737,8 +1812,8 @@ Status BaseEncoder::encode_parallel( Status BaseEncoder::encode_as_ids(const vector &sentences, vector> *ids, bool bos, bool eos, - bool reverse) const { - EncodingConfig encoding_config = {bos, eos, reverse}; + bool reverse, double dropout_prob) const { + EncodingConfig encoding_config = {bos, eos, reverse, dropout_prob}; std::vector decode_results; Status status = encode_parallel(sentences, encoding_config, ID, &decode_results); @@ -1755,9 +1830,9 @@ Status BaseEncoder::encode_as_ids(const vector &sentences, vector &sentences, vector> *subwords, - bool bos, bool eos, bool reverse) const { + bool bos, bool eos, bool reverse, double dropout_prob) const { time_check(""); - EncodingConfig encoding_config = {bos, eos, reverse}; + EncodingConfig encoding_config = {bos, eos, reverse, dropout_prob}; std::vector decode_results; Status status = encode_parallel(sentences, encoding_config, SUBWORD, &decode_results); if (!status.ok()) { @@ -1939,7 +2014,7 @@ void BaseEncoder::vocab_cli(bool verbose) const { } Status BaseEncoder::encode_cli(const string &output_type_str, bool stream, - bool bos, bool eos, bool reverse) const { + bool bos, bool eos, bool reverse, double dropout_prob) const { std::ios_base::sync_with_stdio(false); OutputType output_type; if (output_type_str == "id") { @@ -1953,7 +2028,7 @@ Status BaseEncoder::encode_cli(const string &output_type_str, bool stream, string sentence; while (getline(std::cin, sentence)) { vector> subwords; - Status status = encode_as_subwords({sentence}, &subwords, bos, eos, reverse); + Status status = encode_as_subwords({sentence}, &subwords, bos, eos, reverse, dropout_prob); if (!status.ok()) { return status; } @@ -1964,7 +2039,7 @@ Status BaseEncoder::encode_cli(const string &output_type_str, bool stream, string sentence; while (getline(std::cin, sentence)) { vector> ids; - Status status = encode_as_ids({sentence}, &ids, bos, eos, reverse); + Status status = encode_as_ids({sentence}, &ids, bos, eos, reverse, dropout_prob); if (!status.ok()) { return status; } @@ -1983,7 +2058,7 @@ Status BaseEncoder::encode_cli(const string &output_type_str, bool stream, auto sentences = read_lines_from_stdin(batch_limit, &processed); if (output_type == SUBWORD) { vector> subwords; - Status status = encode_as_subwords(sentences, &subwords, bos, eos, reverse); + Status status = encode_as_subwords(sentences, &subwords, bos, eos, reverse, dropout_prob); if (!status.ok()) { return status; } @@ -1991,7 +2066,7 @@ Status BaseEncoder::encode_cli(const string &output_type_str, bool stream, } else { assert(output_type == ID); vector> ids; - Status status = encode_as_ids(sentences, &ids, bos, eos, reverse); + Status status = encode_as_ids(sentences, &ids, bos, eos, reverse, dropout_prob); if (!status.ok()) { return status; } diff --git a/youtokentome/cpp/bpe.h b/youtokentome/cpp/bpe.h index af74e14..28f1d5c 100644 --- a/youtokentome/cpp/bpe.h +++ b/youtokentome/cpp/bpe.h @@ -38,13 +38,13 @@ class BaseEncoder { Status encode_as_ids( const std::vector &sentences, std::vector> *ids, bool bos = false, - bool eos = false, bool reverse = false) const; + bool eos = false, bool reverse = false, double dropout_prob=0) const; Status encode_as_subwords( const std::vector &sentences, std::vector> *subwords, bool bos = false, - bool eos = false, bool reverse = false) const; + bool eos = false, bool reverse = false, double dropout_prob=0) const; Status id_to_subword(int id, std::string *subword, bool replace_space = false) const; @@ -65,7 +65,7 @@ class BaseEncoder { std::vector vocabulary() const; Status encode_cli(const std::string &output_type, bool stream, bool bos = false, - bool eos = false, bool reverse = false) const; + bool eos = false, bool reverse = false, double dropout_prob = 0) const; Status decode_cli(const std::unordered_set *ignore_ids) const; diff --git a/youtokentome/cpp/utils.h b/youtokentome/cpp/utils.h index b30b5b6..9d8b79d 100644 --- a/youtokentome/cpp/utils.h +++ b/youtokentome/cpp/utils.h @@ -82,6 +82,7 @@ struct EncodingConfig { bool bos; bool eos; bool reverse; + double dropout_prob; }; bool is_space(uint32_t ch); diff --git a/youtokentome/cpp/yttm.pyx b/youtokentome/cpp/yttm.pyx index edb7fcd..1d7774d 100644 --- a/youtokentome/cpp/yttm.pyx +++ b/youtokentome/cpp/yttm.pyx @@ -32,10 +32,10 @@ cdef extern from "bpe.h" namespace "vkcom": cdef cppclass BaseEncoder: BaseEncoder(const string& model_path, int n_threads, Status* status) - Status encode_as_ids(const vector[string] &sentences, vector[vector[int]]* ids, bool bos, bool eos, bool reverse) const - Status encode_as_subwords(const vector[string]& sentences, vector[vector[string]]* subwords, bool bos, bool eos, bool reverse) const + Status encode_as_ids(const vector[string] &sentences, vector[vector[int]]* ids, bool bos, bool eos, bool reverse, double dropout_prob) const + Status encode_as_subwords(const vector[string]& sentences, vector[vector[string]]* subwords, bool bos, bool eos, bool reverse, double dropout_prob) const - Status encode_cli(string output_type, bool stream, bool bos, bool eos, bool reverse) const + Status encode_cli(string output_type, bool stream, bool bos, bool eos, bool reverse, double dropout_prob) const Status decode_cli(const unordered_set[int]* ignore_ids) const @@ -84,30 +84,31 @@ cdef class BPE: if status.code != 0: raise ValueError(status.message.decode()) - def encode(self, sentences, output_type, bos, eos, reverse): + def encode(self, sentences, output_type, bos, eos, reverse, dropout_prob): cdef vector[string] s cdef vector[vector[string]] ret_subwords cdef vector[vector[int]] ret_ids cdef Status status + if dropout_prob < 0 or dropout_prob > 1: + raise ValueError("dropout_prob value must be in the range [0, 1]. Current value of dropout_prob = " + str(dropout_prob)) if output_type == 'id': if isinstance(sentences, str): s = [sentences.encode()] - assert len(s) == 1 - status = self.encoder.encode_as_ids(s, &ret_ids, bos, eos, reverse) + status = self.encoder.encode_as_ids(s, &ret_ids, bos, eos, reverse, dropout_prob) if status.code != 0: raise ValueError(status.message.decode()) return ret_ids[0] assert isinstance(sentences, list) or isinstance(sentences, tuple) s = [x.encode() for x in sentences] - status = self.encoder.encode_as_ids(s, &ret_ids, bos, eos, reverse) + status = self.encoder.encode_as_ids(s, &ret_ids, bos, eos, reverse, dropout_prob) if status.code != 0: raise ValueError(status.message.decode()) return ret_ids elif output_type == 'subword': if isinstance(sentences, str): s = [sentences.encode()] - status = self.encoder.encode_as_subwords(s, &ret_subwords, bos, eos, reverse) + status = self.encoder.encode_as_subwords(s, &ret_subwords, bos, eos, reverse, dropout_prob) if status.code != 0: raise ValueError(status.message.decode()) assert len(ret_subwords) == 1 @@ -115,7 +116,7 @@ cdef class BPE: assert isinstance(sentences, list) or isinstance(sentences, tuple) s = [x.encode() for x in sentences] - status = self.encoder.encode_as_subwords(s, &ret_subwords, bos, eos, reverse) + status = self.encoder.encode_as_subwords(s, &ret_subwords, bos, eos, reverse, dropout_prob) if status.code != 0: raise ValueError(status.message.decode()) return [[piece.decode() for piece in sentence] for sentence in ret_subwords] @@ -163,8 +164,8 @@ cdef class BPE: cdef vector[string] vocab = self.encoder.vocabulary() return [token.decode() for token in vocab] - def encode_cli(self, output_type, stream, bos, eos, reverse): - cdef Status status = self.encoder.encode_cli(output_type.encode(), stream, bos, eos, reverse) + def encode_cli(self, output_type, stream, bos, eos, reverse, dropout_prob): + cdef Status status = self.encoder.encode_cli(output_type.encode(), stream, bos, eos, reverse, dropout_prob) if status.code != 0: raise ValueError(status.message.decode()) diff --git a/youtokentome/youtokentome.py b/youtokentome/youtokentome.py index 3cc61db..fd2b52f 100644 --- a/youtokentome/youtokentome.py +++ b/youtokentome/youtokentome.py @@ -48,6 +48,7 @@ def encode( bos: bool = False, eos: bool = False, reverse: bool = False, + dropout_prob: float = 0, ) -> Union[List[List[int]], List[List[str]]]: if not isinstance(output_type, OutputType): raise TypeError( @@ -62,6 +63,7 @@ def encode( bos=bos, eos=eos, reverse=reverse, + dropout_prob=dropout_prob, ) def vocab_size(self) -> int: diff --git a/youtokentome/yttm_cli.py b/youtokentome/yttm_cli.py index 0555e07..7e66879 100644 --- a/youtokentome/yttm_cli.py +++ b/youtokentome/yttm_cli.py @@ -26,7 +26,7 @@ def main(): @click.option( "--coverage", type=click.FLOAT, - help="Amount of characters covered by the model.", + help="Percentage of characters covered by the model.", default=1.0, show_default=True, ) @@ -98,7 +98,14 @@ def bpe(data, model, vocab_size, coverage, n_threads, pad_id, unk_id, bos_id, eo @click.option( "--stream", is_flag=True, help="Process each line before reading the next one." ) -def encode(model, output_type, n_threads, bos, eos, reverse, stream): +@click.option( + "--dropout_prob", + type=click.FLOAT, + default=0, + show_default=True, + help="BPE-dropout probability (the probability of a merge being dropped)", +) +def encode(model, output_type, n_threads, bos, eos, reverse, stream, dropout_prob): """Encode text to ids or subwords.""" if n_threads < -1 or n_threads == 0: raise ValueError( @@ -107,7 +114,7 @@ def encode(model, output_type, n_threads, bos, eos, reverse, stream): ) bpe = yttmc.BPE(model, n_threads) - bpe.encode_cli(output_type, stream, bos, eos, reverse) + bpe.encode_cli(output_type, stream, bos, eos, reverse, dropout_prob) def validate_ignore_ids(ctx, param, value):