Skip to content
This repository has been archived by the owner on Apr 23, 2024. It is now read-only.

Commit

Permalink
Change the native code to use stream interfaces
Browse files Browse the repository at this point in the history
Signed-off-by: Vadim Markovtsev <[email protected]>
  • Loading branch information
vmarkovtsev committed Oct 10, 2019
1 parent 511ae6d commit 70377ad
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 44 deletions.
9 changes: 5 additions & 4 deletions youtokentome/cpp/bpe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ void rename_tokens(ska::flat_hash_map<uint32_t, uint32_t> &char2id,
}

BPEState learn_bpe_from_string(string &text_utf8, int n_tokens,
const string &output_file,
StreamWriter &output,
BpeConfig bpe_config) {
vector<std::thread> threads;
assert(bpe_config.n_threads >= 1 || bpe_config.n_threads == -1);
Expand Down Expand Up @@ -1294,8 +1294,8 @@ BPEState learn_bpe_from_string(string &text_utf8, int n_tokens,
rename_tokens(char2id, rules, bpe_config.special_tokens, n_tokens);

BPEState bpe_state = {char2id, rules, bpe_config.special_tokens};
bpe_state.dump(output_file);
std::cerr << "model saved to: " << output_file << std::endl;
bpe_state.dump(output);
std::cerr << "model saved to: " << output.name() << std::endl;
return bpe_state;
}

Expand Down Expand Up @@ -1450,7 +1450,8 @@ void train_bpe(const string &input_path, const string &model_path,
std::cerr << "reading file..." << std::endl;
auto data = fast_read_file_utf8(input_path);
std::cerr << "learning bpe..." << std::endl;
learn_bpe_from_string(data, vocab_size, model_path, bpe_config);
auto fout = StreamWriter.open(model_path);
learn_bpe_from_string(data, vocab_size, fout, bpe_config);
}

DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8,
Expand Down
2 changes: 1 addition & 1 deletion youtokentome/cpp/bpe.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class BaseEncoder {

explicit BaseEncoder(BPEState bpe_state, int _n_threads);

explicit BaseEncoder(const std::string& model_path, int n_threads);
explicit BaseEncoder(const StreamReader& model_path, int n_threads);

void fill_from_state();

Expand Down
78 changes: 60 additions & 18 deletions youtokentome/cpp/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,62 @@ namespace vkcom {
using std::string;
using std::vector;

class FileWriter : public StreamWriter {
public:
FileWriter(const std::string &file_name) {
this->file_name = file_name;
this->fout = std::ofstream(file_name, std::ios::out | std::ios::binary);
if (fout.fail()) {
std::cerr << "Can't open file: " << file_name << std::endl;
assert(false);
}
}

virtual int write(const char *buffer, int size) override {
return fout.write(buffer, size);
}

virtual std::string name() const noexcept override {
return file_name;
}

private:
std::string file_name;
std::ofstream fout;
};

class FileReader : public StreamReader {
public:
FileReader(const std::string &file_name) {
this->file_name = file_name;
this->fin = std::ifstream(file_name, std::ios::in | std::ios::binary);
if (fin.fail()) {
std::cerr << "Can't open file: " << file_name << std::endl;
assert(false);
}
}

virtual int read(const char *buffer, int size) override {
return fin.read(buffer, size);
}

virtual std::string name() const noexcept override {
return file_name;
}

private:
std::string file_name;
std::ifstream fin;
};

StreamWriter StreamWriter::open(const std::string &file_name) {
return FileWriter(file_name);
}

StreamReader StreamReader::open(const std::string &file_name) {
return FileReader(file_name);
}

template<typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
T bin_to_int(const char *val) {
uint32_t ret = static_cast<unsigned char>(val[0]);
Expand All @@ -31,7 +87,7 @@ std::unique_ptr<char[]> int_to_bin(T val) {
return std::move(ret);
}

void SpecialTokens::dump(std::ofstream &fout) {
void SpecialTokens::dump(StreamWriter &fout) {
std::unique_ptr<char[]> unk_id_ptr(int_to_bin(unk_id)),
pad_id_ptr(int_to_bin(pad_id)),
bos_id_ptr(int_to_bin(bos_id)),
Expand All @@ -42,7 +98,7 @@ void SpecialTokens::dump(std::ofstream &fout) {
fout.write(eos_id_ptr.get(), 4);
}

void SpecialTokens::load(std::ifstream &fin) {
void SpecialTokens::load(StreamReader &fin) {
char unk_id_bs[4], pad_id_bs[4], bos_id_bs[4], eos_id_bs[4];
fin.read(unk_id_bs, 4);
fin.read(pad_id_bs, 4);
Expand Down Expand Up @@ -85,13 +141,7 @@ bool BPE_Rule::operator==(const BPE_Rule &other) const {

BPE_Rule::BPE_Rule(uint32_t x, uint32_t y, uint32_t z) : x(x), y(y), z(z) {}

void BPEState::dump(const string &file_name) {
std::ofstream fout(file_name, std::ios::out | std::ios::binary);
if (fout.fail()) {
std::cerr << "Can't open file: " << file_name << std::endl;
assert(false);
}

void BPEState::dump(StreamWriter &fout) {
std::unique_ptr<char[]> char2id_ptr(int_to_bin(char2id.size())),
rules_ptr(int_to_bin(rules.size()));
fout.write(char2id_ptr.get(), 4);
Expand All @@ -115,18 +165,11 @@ void BPEState::dump(const string &file_name) {
fout.write(rule_ptr.get(), 4);
}
special_tokens.dump(fout);
fout.close();
}

void BPEState::load(const string &file_name) {
void BPEState::load(StreamReader &fin) {
char2id.clear();
rules.clear();
std::ifstream fin(file_name, std::ios::in | std::ios::binary);
if (fin.fail()) {
std::cerr << "Error. Can not open file with model: " << file_name
<< std::endl;
exit(EXIT_FAILURE);
}
char n_bs[4], m_bs[4];
fin.read(n_bs, 4);
fin.read(m_bs, 4);
Expand Down Expand Up @@ -161,7 +204,6 @@ void BPEState::load(const string &file_name) {
rules.emplace_back(std::get<0>(rules_xyz[i]), std::get<1>(rules_xyz[i]), std::get<2>(rules_xyz[i]));
}
special_tokens.load(fin);
fin.close();
}

BpeConfig::BpeConfig(double _character_coverage, int _n_threads,
Expand Down
22 changes: 18 additions & 4 deletions youtokentome/cpp/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,20 @@
namespace vkcom {
const uint32_t SPACE_TOKEN = 9601;

struct StreamWriter {
virtual int write(const char *buffer, int size) = 0;
virtual std::string name() const noexcept = 0;

static StreamWriter open(const std::string &file_name);
};

struct StreamReader {
virtual int read(const char *buffer, int size) = 0;
virtual std::string name() const noexcept = 0;

static StreamReader open(const std::string &file_name);
};

struct BPE_Rule {
// x + y -> z
uint32_t x{0};
Expand All @@ -31,9 +45,9 @@ struct SpecialTokens {

SpecialTokens(int pad_id, int unk_id, int bos_id, int eos_id);

void dump(std::ofstream &fout);
void dump(StreamWriter &fout);

void load(std::ifstream &fin);
void load(StreamReader &fin);

uint32_t max_id() const;

Expand All @@ -58,9 +72,9 @@ struct BPEState {
std::vector<BPE_Rule> rules;
SpecialTokens special_tokens;

void dump(const std::string &file_name);
void dump(StreamWriter &fout);

void load(const std::string &file_name);
void load(StreamReader &fin);
};

struct DecodeResult {
Expand Down
67 changes: 50 additions & 17 deletions youtokentome/youtokentome.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,32 @@
from enum import Enum
from typing import List, Union
from functools import wraps
from typing import BinaryIO, List, Optional, Union

import _youtokentome_cython


class OutputType(Enum):
ID = 1
SUBWORD = 2


class BPE:
def __init__(self, model: str, n_threads: int = -1):
self.bpe_cython = _youtokentome_cython.BPE(
model_path=model, n_threads=n_threads
)
def __init__(self, model: Union[str, BinaryIO], n_threads: int = -1):
own_obj = isinstance(model, str)
if own_obj:
model = open(model, "rb")
try:
self.bpe_cython = _youtokentome_cython.BPE(
model_fobj=model, n_threads=n_threads
)
finally:
if own_obj:
model.close()

@staticmethod
def train(
data: str,
model: str,
model: Optional[Union[str, BinaryIO]],
vocab_size: int,
coverage: float = 1.0,
n_threads: int = -1,
Expand All @@ -25,17 +35,24 @@ def train(
bos_id: int = 2,
eos_id: int = 3,
) -> "BPE":
_youtokentome_cython.BPE.train(
data=data,
model=model,
vocab_size=vocab_size,
n_threads=n_threads,
coverage=coverage,
pad_id=pad_id,
unk_id=unk_id,
bos_id=bos_id,
eos_id=eos_id,
)
own_obj = isinstance(model, str)
if own_obj:
model = open(model, "wb")
try:
_youtokentome_cython.BPE.train(
data=data,
model=model,
vocab_size=vocab_size,
n_threads=n_threads,
coverage=coverage,
pad_id=pad_id,
unk_id=unk_id,
bos_id=bos_id,
eos_id=eos_id,
)
finally:
if own_obj:
model.close()

return BPE(model=model, n_threads=n_threads)

Expand All @@ -61,6 +78,22 @@ def encode(
reverse=reverse,
)

def save(self, where: Union[str, BinaryIO]):
"""
Write the model to FS or any writeable file object.
:param where: FS path or writeable file object.
:return: None
"""
own_obj = isinstance(where, str)
if own_obj:
where = open(where, "wb")
try:
self.bpe_cython.save(where=where)
finally:
if own_obj:
where.close()

def vocab_size(self) -> int:
return self.bpe_cython.vocab_size()

Expand Down

0 comments on commit 70377ad

Please sign in to comment.