Skip to content
This repository has been archived by the owner on Apr 23, 2024. It is now read-only.

[WIP] File objects #21

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Change the native code to use stream interfaces
Signed-off-by: Vadim Markovtsev <vadim@sourced.tech>
vmarkovtsev committed Oct 11, 2019
commit b5ca533c7bc91df52e5a3557d618c25677fc054d
9 changes: 5 additions & 4 deletions youtokentome/cpp/bpe.cpp
Original file line number Diff line number Diff line change
@@ -865,7 +865,7 @@ void rename_tokens(ska::flat_hash_map<uint32_t, uint32_t> &char2id,
}

BPEState learn_bpe_from_string(string &text_utf8, int n_tokens,
const string &output_file,
StreamWriter &output,
BpeConfig bpe_config) {
vector<std::thread> threads;
assert(bpe_config.n_threads >= 1 || bpe_config.n_threads == -1);
@@ -1294,8 +1294,8 @@ BPEState learn_bpe_from_string(string &text_utf8, int n_tokens,
rename_tokens(char2id, rules, bpe_config.special_tokens, n_tokens);

BPEState bpe_state = {char2id, rules, bpe_config.special_tokens};
bpe_state.dump(output_file);
std::cerr << "model saved to: " << output_file << std::endl;
bpe_state.dump(output);
std::cerr << "model saved to: " << output.name() << std::endl;
return bpe_state;
}

@@ -1450,7 +1450,8 @@ void train_bpe(const string &input_path, const string &model_path,
std::cerr << "reading file..." << std::endl;
auto data = fast_read_file_utf8(input_path);
std::cerr << "learning bpe..." << std::endl;
learn_bpe_from_string(data, vocab_size, model_path, bpe_config);
auto fout = StreamWriter.open(model_path);
learn_bpe_from_string(data, vocab_size, fout, bpe_config);
}

DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8,
2 changes: 1 addition & 1 deletion youtokentome/cpp/bpe.h
Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@ class BaseEncoder {

explicit BaseEncoder(BPEState bpe_state, int _n_threads);

explicit BaseEncoder(const std::string& model_path, int n_threads);
explicit BaseEncoder(const StreamReader& model_path, int n_threads);

void fill_from_state();

78 changes: 60 additions & 18 deletions youtokentome/cpp/utils.cpp
Original file line number Diff line number Diff line change
@@ -11,6 +11,62 @@ namespace vkcom {
using std::string;
using std::vector;

class FileWriter : public StreamWriter {
public:
FileWriter(const std::string &file_name) {
this->file_name = file_name;
this->fout = std::ofstream(file_name, std::ios::out | std::ios::binary);
if (fout.fail()) {
std::cerr << "Can't open file: " << file_name << std::endl;
assert(false);
}
}

virtual int write(const char *buffer, int size) override {
return fout.write(buffer, size);
}

virtual std::string name() const noexcept override {
return file_name;
}

private:
std::string file_name;
std::ofstream fout;
};

class FileReader : public StreamReader {
public:
FileReader(const std::string &file_name) {
this->file_name = file_name;
this->fin = std::ifstream(file_name, std::ios::in | std::ios::binary);
if (fin.fail()) {
std::cerr << "Can't open file: " << file_name << std::endl;
assert(false);
}
}

virtual int read(const char *buffer, int size) override {
return fin.read(buffer, size);
}

virtual std::string name() const noexcept override {
return file_name;
}

private:
std::string file_name;
std::ifstream fin;
};

StreamWriter StreamWriter::open(const std::string &file_name) {
return FileWriter(file_name);
}

StreamReader StreamReader::open(const std::string &file_name) {
return FileReader(file_name);
}

template<typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
T bin_to_int(const char *val) {
uint32_t ret = static_cast<unsigned char>(val[0]);
@@ -31,7 +87,7 @@ std::unique_ptr<char[]> int_to_bin(T val) {
return std::move(ret);
}

void SpecialTokens::dump(std::ofstream &fout) {
void SpecialTokens::dump(StreamWriter &fout) {
std::unique_ptr<char[]> unk_id_ptr(int_to_bin(unk_id)),
pad_id_ptr(int_to_bin(pad_id)),
bos_id_ptr(int_to_bin(bos_id)),
@@ -42,7 +98,7 @@ void SpecialTokens::dump(std::ofstream &fout) {
fout.write(eos_id_ptr.get(), 4);
}

void SpecialTokens::load(std::ifstream &fin) {
void SpecialTokens::load(StreamReader &fin) {
char unk_id_bs[4], pad_id_bs[4], bos_id_bs[4], eos_id_bs[4];
fin.read(unk_id_bs, 4);
fin.read(pad_id_bs, 4);
@@ -85,13 +141,7 @@ bool BPE_Rule::operator==(const BPE_Rule &other) const {

BPE_Rule::BPE_Rule(uint32_t x, uint32_t y, uint32_t z) : x(x), y(y), z(z) {}

void BPEState::dump(const string &file_name) {
std::ofstream fout(file_name, std::ios::out | std::ios::binary);
if (fout.fail()) {
std::cerr << "Can't open file: " << file_name << std::endl;
assert(false);
}

void BPEState::dump(StreamWriter &fout) {
std::unique_ptr<char[]> char2id_ptr(int_to_bin(char2id.size())),
rules_ptr(int_to_bin(rules.size()));
fout.write(char2id_ptr.get(), 4);
@@ -115,18 +165,11 @@ void BPEState::dump(const string &file_name) {
fout.write(rule_ptr.get(), 4);
}
special_tokens.dump(fout);
fout.close();
}

void BPEState::load(const string &file_name) {
void BPEState::load(StreamReader &fin) {
char2id.clear();
rules.clear();
std::ifstream fin(file_name, std::ios::in | std::ios::binary);
if (fin.fail()) {
std::cerr << "Error. Can not open file with model: " << file_name
<< std::endl;
exit(EXIT_FAILURE);
}
char n_bs[4], m_bs[4];
fin.read(n_bs, 4);
fin.read(m_bs, 4);
@@ -161,7 +204,6 @@ void BPEState::load(const string &file_name) {
rules.emplace_back(std::get<0>(rules_xyz[i]), std::get<1>(rules_xyz[i]), std::get<2>(rules_xyz[i]));
}
special_tokens.load(fin);
fin.close();
}

BpeConfig::BpeConfig(double _character_coverage, int _n_threads,
24 changes: 20 additions & 4 deletions youtokentome/cpp/utils.h
Original file line number Diff line number Diff line change
@@ -8,6 +8,22 @@
namespace vkcom {
const uint32_t SPACE_TOKEN = 9601;

struct StreamWriter {
virtual int write(const char *buffer, int size) = 0;
virtual std::string name() const noexcept = 0;
virtual ~StreamWriter() = default;

static StreamWriter open(const std::string &file_name);
};

struct StreamReader {
virtual int read(const char *buffer, int size) = 0;
virtual std::string name() const noexcept = 0;
virtual ~StreamReader() = default;

static StreamReader open(const std::string &file_name);
};

struct BPE_Rule {
// x + y -> z
uint32_t x{0};
@@ -31,9 +47,9 @@ struct SpecialTokens {

SpecialTokens(int pad_id, int unk_id, int bos_id, int eos_id);

void dump(std::ofstream &fout);
void dump(StreamWriter &fout);

void load(std::ifstream &fin);
void load(StreamReader &fin);

uint32_t max_id() const;

@@ -58,9 +74,9 @@ struct BPEState {
std::vector<BPE_Rule> rules;
SpecialTokens special_tokens;

void dump(const std::string &file_name);
void dump(StreamWriter &fout);

void load(const std::string &file_name);
void load(StreamReader &fin);
};

struct DecodeResult {
67 changes: 50 additions & 17 deletions youtokentome/youtokentome.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,32 @@
from enum import Enum
from typing import List, Union
from functools import wraps
from typing import BinaryIO, List, Optional, Union

import _youtokentome_cython


class OutputType(Enum):
ID = 1
SUBWORD = 2


class BPE:
def __init__(self, model: str, n_threads: int = -1):
self.bpe_cython = _youtokentome_cython.BPE(
model_path=model, n_threads=n_threads
)
def __init__(self, model: Union[str, BinaryIO], n_threads: int = -1):
own_obj = isinstance(model, str)
if own_obj:
model = open(model, "rb")
try:
self.bpe_cython = _youtokentome_cython.BPE(
model_fobj=model, n_threads=n_threads
)
finally:
if own_obj:
model.close()

@staticmethod
def train(
data: str,
model: str,
model: Optional[Union[str, BinaryIO]],
vocab_size: int,
coverage: float = 1.0,
n_threads: int = -1,
@@ -25,17 +35,24 @@ def train(
bos_id: int = 2,
eos_id: int = 3,
) -> "BPE":
_youtokentome_cython.BPE.train(
data=data,
model=model,
vocab_size=vocab_size,
n_threads=n_threads,
coverage=coverage,
pad_id=pad_id,
unk_id=unk_id,
bos_id=bos_id,
eos_id=eos_id,
)
own_obj = isinstance(model, str)
if own_obj:
model = open(model, "wb")
try:
_youtokentome_cython.BPE.train(
data=data,
model=model,
vocab_size=vocab_size,
n_threads=n_threads,
coverage=coverage,
pad_id=pad_id,
unk_id=unk_id,
bos_id=bos_id,
eos_id=eos_id,
)
finally:
if own_obj:
model.close()

return BPE(model=model, n_threads=n_threads)

@@ -61,6 +78,22 @@ def encode(
reverse=reverse,
)

def save(self, where: Union[str, BinaryIO]):
"""
Write the model to FS or any writeable file object.

:param where: FS path or writeable file object.
:return: None
"""
own_obj = isinstance(where, str)
if own_obj:
where = open(where, "wb")
try:
self.bpe_cython.save(where=where)
finally:
if own_obj:
where.close()

def vocab_size(self) -> int:
return self.bpe_cython.vocab_size()