Skip to content
This repository has been archived by the owner on Apr 23, 2024. It is now read-only.

Commit

Permalink
Change the model format to binary
Browse files Browse the repository at this point in the history
We are using big endian numbers under the hood.

Rule's x, y and z are written by plane, not interleaved.

Signed-off-by: Vadim Markovtsev <[email protected]>
  • Loading branch information
vmarkovtsev committed Oct 10, 2019
1 parent f5f4bf3 commit 511ae6d
Showing 1 changed file with 90 additions and 19 deletions.
109 changes: 90 additions & 19 deletions youtokentome/cpp/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,56 @@
#include <cassert>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>


namespace vkcom {
using std::string;
using std::vector;

template<typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
T bin_to_int(const char *val) {
uint32_t ret = static_cast<unsigned char>(val[0]);
ret |= static_cast<uint32_t>(static_cast<unsigned char>(val[1])) << 8;
ret |= static_cast<uint32_t>(static_cast<unsigned char>(val[2])) << 16;
ret |= static_cast<uint32_t>(static_cast<unsigned char>(val[3])) << 24;
return ret;
}

template<typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
std::unique_ptr<char[]> int_to_bin(T val) {
auto u32 = static_cast<uint32_t>(val);
std::unique_ptr<char[]> ret(new char[4]);
ret[0] = u32 & 0xFF;
ret[1] = (u32 >> 8) & 0xFF;
ret[2] = (u32 >> 16) & 0xFF;
ret[3] = (u32 >> 24); // no need for & 0xFF
return std::move(ret);
}

void SpecialTokens::dump(std::ofstream &fout) {
fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id
<< std::endl;
std::unique_ptr<char[]> unk_id_ptr(int_to_bin(unk_id)),
pad_id_ptr(int_to_bin(pad_id)),
bos_id_ptr(int_to_bin(bos_id)),
eos_id_ptr(int_to_bin(eos_id));
fout.write(unk_id_ptr.get(), 4);
fout.write(pad_id_ptr.get(), 4);
fout.write(bos_id_ptr.get(), 4);
fout.write(eos_id_ptr.get(), 4);
}

void SpecialTokens::load(std::ifstream &fin) {
fin >> unk_id >> pad_id >> bos_id >> eos_id;
char unk_id_bs[4], pad_id_bs[4], bos_id_bs[4], eos_id_bs[4];
fin.read(unk_id_bs, 4);
fin.read(pad_id_bs, 4);
fin.read(bos_id_bs, 4);
fin.read(eos_id_bs, 4);
this->unk_id = bin_to_int<int>(unk_id_bs);
this->pad_id = bin_to_int<int>(pad_id_bs);
this->bos_id = bin_to_int<int>(bos_id_bs);
this->eos_id = bin_to_int<int>(eos_id_bs);
}

uint32_t SpecialTokens::max_id() const {
Expand Down Expand Up @@ -50,18 +86,33 @@ bool BPE_Rule::operator==(const BPE_Rule &other) const {
BPE_Rule::BPE_Rule(uint32_t x, uint32_t y, uint32_t z) : x(x), y(y), z(z) {}

void BPEState::dump(const string &file_name) {
std::ofstream fout(file_name, std::ios::out);
std::ofstream fout(file_name, std::ios::out | std::ios::binary);
if (fout.fail()) {
std::cerr << "Can't open file: " << file_name << std::endl;
assert(false);
}
fout << char2id.size() << " " << rules.size() << std::endl;
for (auto s : char2id) {
fout << s.first << " " << s.second << std::endl;
}

for (auto rule : rules) {
fout << rule.x << " " << rule.y << " " << rule.z << std::endl;
std::unique_ptr<char[]> char2id_ptr(int_to_bin(char2id.size())),
rules_ptr(int_to_bin(rules.size()));
fout.write(char2id_ptr.get(), 4);
fout.write(rules_ptr.get(), 4);
for (auto &s : char2id) {
std::unique_ptr<char[]> first_ptr(int_to_bin(s.first)),
second_ptr(int_to_bin(s.second));
fout.write(first_ptr.get(), 4);
fout.write(second_ptr.get(), 4);
}
for (auto &rule : rules) {
std::unique_ptr<char[]> rule_ptr(int_to_bin(rule.x));
fout.write(rule_ptr.get(), 4);
}
for (auto &rule : rules) {
std::unique_ptr<char[]> rule_ptr(int_to_bin(rule.y));
fout.write(rule_ptr.get(), 4);
}
for (auto &rule : rules) {
std::unique_ptr<char[]> rule_ptr(int_to_bin(rule.z));
fout.write(rule_ptr.get(), 4);
}
special_tokens.dump(fout);
fout.close();
Expand All @@ -70,24 +121,44 @@ void BPEState::dump(const string &file_name) {
void BPEState::load(const string &file_name) {
char2id.clear();
rules.clear();
std::ifstream fin(file_name, std::ios::in);
std::ifstream fin(file_name, std::ios::in | std::ios::binary);
if (fin.fail()) {
std::cerr << "Error. Can not open file with model: " << file_name
<< std::endl;
exit(EXIT_FAILURE);
}
int n, m;
fin >> n >> m;
char n_bs[4], m_bs[4];
fin.read(n_bs, 4);
fin.read(m_bs, 4);
auto n = bin_to_int<int>(n_bs);
auto m = bin_to_int<int>(m_bs);
for (int i = 0; i < n; i++) {
uint32_t inner_id;
uint32_t utf32_id;
fin >> inner_id >> utf32_id;
char inner_id_bs[4], utf32_id_bs[4];
fin.read(inner_id_bs, 4);
fin.read(utf32_id_bs, 4);
auto inner_id = bin_to_int<uint32_t>(inner_id_bs);
auto utf32_id = bin_to_int<uint32_t>(utf32_id_bs);
char2id[inner_id] = utf32_id;
}
std::vector<std::tuple<uint32_t, uint32_t, uint32_t>> rules_xyz(m);
for (int j = 0; j < 3; j++) {
for (int i = 0; i < m; i++) {
char val[4];
fin.read(val, 4);
uint32_t *element;
switch (j) {
case 0:
element = &std::get<0>(rules_xyz[i]);
case 1:
element = &std::get<1>(rules_xyz[i]);
case 2:
element = &std::get<2>(rules_xyz[i]);
}
*element = bin_to_int<uint32_t>(val);
}
}
for (int i = 0; i < m; i++) {
uint32_t x, y, z;
fin >> x >> y >> z;
rules.emplace_back(x, y, z);
rules.emplace_back(std::get<0>(rules_xyz[i]), std::get<1>(rules_xyz[i]), std::get<2>(rules_xyz[i]));
}
special_tokens.load(fin);
fin.close();
Expand Down

0 comments on commit 511ae6d

Please sign in to comment.