-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathbpe.cpp
123 lines (112 loc) · 2.72 KB
/
bpe.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include <stdio.h>
#include <stdint.h>
#include <unordered_map>
#include "bpe.h"
// Construction and destruction do no work of their own; the members'
// default initialization and destructors are all that runs.
BPEDecoder::BPEDecoder() = default;
BPEDecoder::~BPEDecoder() = default;
// Loads a vocabulary from the binary file at `vocab_path`, appending one
// string to vocab_ per entry, in file order.
//
// File format: a sequence of entries, each a single length byte (0-255)
// followed by that many bytes of token text.
//
// Returns false if `vocab_path` is null or the file cannot be opened, true
// otherwise. Reading stops at end of file or at the first truncated entry;
// entries read before that point are kept and the call still returns true.
bool BPEDecoder::Init(const char* vocab_path) {
    if (!vocab_path) {
        return false;
    }
    FILE* fp = fopen(vocab_path, "rb");
    if (!fp) {
        return false;
    }
    for (;;) {
        uint8_t len;
        if (fread(&len, 1, 1, fp) != 1) {
            break;  // end of file (or read error) at an entry boundary
        }
        // len is at most 255, so there is always room for the terminator.
        char buf[256];
        if (fread(buf, 1, len, fp) != len) {
            break;  // truncated entry: discard it and stop
        }
        buf[len] = 0;
        // NOTE(review): the entry is stored via its NUL-terminated form, so a
        // token containing an embedded 0 byte is cut short at that byte.
        // Confirm whether the vocabulary can contain such tokens.
        vocab_.push_back(buf);
    }
    // The handle was previously leaked on every successful call.
    fclose(fp);
    return true;
}
// Concatenates the text of `ntokens` token ids from `tokens` into `outbuf`
// as a NUL-terminated string.
//
// A token id of -1 is rendered as the literal marker "(?)". Any other id that
// is negative or not a valid index into vocab_ stops decoding. Decoding also
// stops before the first piece that would not fit together with the
// terminating NUL, so at most outbuf_size bytes are ever written.
//
// Returns the number of bytes written, not counting the terminator. If
// `outbuf` is null or `outbuf_size` is not positive, nothing is written and
// 0 is returned.
int BPEDecoder::Decode(const int* tokens, int ntokens, char* outbuf, int outbuf_size) {
    if (!outbuf || outbuf_size <= 0) {
        return 0;  // no room even for the terminator
    }
    static const char kUnknownMarker[] = "(?)";
    const int kUnknownMarkerLen = 3;

    // Invariant: j < outbuf_size, so outbuf[j] is always a valid slot for the
    // terminator written after the loop.
    int j = 0;
    for (int i = 0; i < ntokens; i++) {
        const int id = tokens[i];
        if (id == -1) {
            // The marker needs the same space check as a regular token;
            // without it up to three bytes could be written past the buffer.
            if (j + kUnknownMarkerLen >= outbuf_size) {
                break;
            }
            for (int k = 0; k < kUnknownMarkerLen; k++) {
                outbuf[j++] = kUnknownMarker[k];
            }
            continue;
        }
        if (id < 0 || static_cast<size_t>(id) >= vocab_.size()) {
            break;
        }
        const auto& piece = vocab_[id];
        const int len = static_cast<int>(piece.size());
        if (j + len >= outbuf_size) {
            break;
        }
        const char* s = piece.c_str();
        for (int k = 0; k < len; k++) {
            outbuf[j++] = *s++;
        }
    }
    outbuf[j] = 0;
    return j;
}
// One node of a trie keyed on individual chars of token text. A node owns
// all of its children: destroying a node destroys the whole subtree below it.
struct BPETrieNode {
    int token_length = -1;  // -1 indicates there is no token ending at this node
    int token_id = -1;      // only meaningful when token_length != -1
    std::unordered_map<char, BPETrieNode*> children;

    BPETrieNode() = default;
    // Children are owned raw pointers, so a memberwise copy would lead to a
    // double delete; copying is disabled.
    BPETrieNode(const BPETrieNode&) = delete;
    BPETrieNode& operator=(const BPETrieNode&) = delete;

    ~BPETrieNode() {
        for (auto& entry : children) {
            delete entry.second;
        }
    }
};
// Starts out with an empty trie: a single root node with no children.
BPEEncoder::BPEEncoder() : root_(new BPETrieNode()) {}

// Deleting the root tears down the whole trie, since each node's destructor
// deletes its children.
BPEEncoder::~BPEEncoder() {
    delete root_;
}
bool BPEEncoder::Init(const std::vector<std::string>& vocab) {
for (int i = 0; i < vocab.size(); i++) {
auto token = vocab[i];
BPETrieNode* node = root_;
for (size_t i = 0; i < token.size(); i++) {
char key = token[i];
if (node->children.count(key) == 0) {
node->children[key] = new BPETrieNode();
}
node = node->children[key];
}
node->token_length = token.size();
node->token_id = i;
}
return true;
}
const char* BPEEncoder::Encode(const char *string, int *outbuf, int outbuf_size, int *ntokens) {
*ntokens = 0;
while(*string && *ntokens < outbuf_size) {
BPETrieNode* node = root_;
int last_token_length = -1;
int last_token_id = -1;
for (size_t i = 0; string[i]; i++) {
char key = string[i];
if (node->children.count(key) == 0) {
break;
}
node = node->children[key];
if (node->token_length != -1) {
last_token_length = node->token_length;
last_token_id = node->token_id;
}
}
if (last_token_length == -1) {
return string;
} else {
*outbuf++ = last_token_id;
string += last_token_length;
(*ntokens)++;
}
}
return string;
}