From ca6acc0ac5b1ef326480d938e76883e22ba415f0 Mon Sep 17 00:00:00 2001 From: Gautier Dagan Date: Mon, 13 Nov 2023 17:37:40 +0000 Subject: [PATCH] maturin working --- Cargo.toml | 7 +- README.md | 2 +- pyproject.toml | 16 ++ src/lib.rs | 16 ++ src/main.rs | 167 +++++++++++++ src/tiktoken.rs | 608 ++++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 814 insertions(+), 2 deletions(-) create mode 100644 pyproject.toml create mode 100644 src/lib.rs create mode 100644 src/main.rs create mode 100644 src/tiktoken.rs diff --git a/Cargo.toml b/Cargo.toml index bc552c7..648b540 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,11 +3,16 @@ name = "bpeasy" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +name = "bpeasy" +crate-type = ["cdylib"] [dependencies] +pyo3 = { version = "0.19.0", features = ["extension-module"] } regex = "1.5.4" [[bin]] name = "bpeasy" path = "main.rs" + + diff --git a/README.md b/README.md index 543c72d..93bb025 100644 --- a/README.md +++ b/README.md @@ -1 +1 @@ -# bpeasy \ No newline at end of file +# bpeasy diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..68912ac --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,16 @@ +[build-system] +requires = ["maturin>=1.3,<2.0"] +build-backend = "maturin" + +[project] +name = "bpeasy" +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dynamic = ["version"] + +[tool.maturin] +features = ["pyo3/extension-module"] diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..429be80 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,16 @@ +use pyo3::prelude::*; + +/// Formats the sum of two numbers as string. +#[pyfunction] +fn sum_as_string(a: usize, b: usize) -> PyResult { + Ok((a + b).to_string()) +} + +/// A Python module implemented in Rust. The name of this function must match +/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to +/// import the module. +#[pymodule] +fn bpeasy(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; + Ok(()) +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..809be71 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,167 @@ +extern crate regex; +use pyo3::prelude::*; +use regex::Regex; +use std::collections::HashMap; + +/// Formats the sum of two numbers as string. +#[pyfunction] +fn sum_as_string(a: usize, b: usize) -> PyResult { + Ok((a + b).to_string()) +} + +/// A Python module implemented in Rust. +#[pymodule] +fn bpeasy(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; + Ok(()) +} + +fn tokenize(text: &str) -> Vec { + // regex splits + let re = Regex::new(r"([^\s]+)|(\s+)").unwrap(); + re.find_iter(text) + .map(|mat| mat.as_str().to_string()) + .collect() +} + +fn convert_to_tokenized_bytes(tokenized_text: Vec) -> Vec>> { + let mut tokenized_bytes: Vec>> = Vec::new(); + for token in tokenized_text { + let mut tokenized_byte: Vec> = Vec::new(); + for byte in token.bytes() { + tokenized_byte.push(vec![byte]); + } + tokenized_bytes.push(tokenized_byte); + } + tokenized_bytes +} + +fn initialize_vocab_bytes() -> HashMap, u64> { + let mut vocab: HashMap, u64> = HashMap::new(); + for i in 0..255 { + vocab.insert(vec![i], i as u64); + } + vocab +} + +fn get_most_frequent_pair(tokenized_bytes: &mut Vec>>) -> Option<(Vec, Vec)> { + /* + Calculate frequencies for each pair of bytes in all sentences and words + Return the most frequent pair of bytes + */ + + let mut pair_freqs: HashMap<(Vec, Vec), u128> = HashMap::new(); + + // Calculate frequencies for each pair of bytes in all sentences and words + for sentence in tokenized_bytes { + for word in sentence.windows(2) { + if let [a, b] = word { + *pair_freqs.entry((a.to_vec(), b.to_vec())).or_insert(0) += 1; + } + } + } + // println!("{:?}", pair_freqs); + let most_frequent_pair = pair_freqs.iter().max_by_key(|&(_, count)| count); + println!("Most frequent pair: {:?}", most_frequent_pair); + if most_frequent_pair.is_none() { + return None; + } + let ((ref left, ref right), _count) = most_frequent_pair.unwrap(); + Some((left.clone(), right.clone())) +} + +fn merge_frequent_pair(tokenized_bytes: &mut Vec>>, left: Vec, right: Vec) { + // Merge the most frequent pair in all sentences and words + for sentence in tokenized_bytes.iter_mut() { + let mut i = 0; + while i < sentence.len() - 1 { + // Check if the current and next token form the most frequent pair + if sentence[i] == left.clone() && sentence[i + 1] == right.clone() { + // Merge the pair and replace the first element with the merged pair + let merged = [&sentence[i][..], &sentence[i + 1][..]].concat(); + sentence[i] = merged; + // Remove the second element of the pair + sentence.remove(i + 1); + // Do not increment i, as we want to check the next pair starting from the current position + } else { + i += 1; // Move to the next token + } + } + } +} + +// fn print_vocab_bytes(vocab: &HashMap, u64>) { +// // sort by value +// let mut sorted_vocab: Vec<_> = vocab.iter().collect(); +// sorted_vocab.sort_by(|a, b| a.1.cmp(b.1)); +// for (key, value) in sorted_vocab { +// // try to convert to string +// let key_str = String::from_utf8_lossy(key); +// println!("{:?}: {}", key_str, value); +// } +// } + +fn build_bpe_vocab( + mut tokenized_bytes: Vec>>, + vocab_size: usize, +) -> HashMap, u64> { + let mut vocab: HashMap, u64> = initialize_vocab_bytes(); + + println!("{:?}", vocab); + + let mut num_token_added = 0; + while num_token_added < vocab_size { + println!("Iteration: {}", num_token_added); + + let most_frequent_pair = get_most_frequent_pair(&mut tokenized_bytes); + if most_frequent_pair.is_none() { + break; + } + let (left, right) = most_frequent_pair.unwrap(); + + // Merge the most frequent pair in all sentences and words + merge_frequent_pair(&mut tokenized_bytes, left.clone(), right.clone()); + + let mut token = left.clone(); // Clone the first token + token.extend(right); // Extend with the second token + // Now, combined_token contains the merged pair + println!("Combined token: {:?}", token); + + // combine pair into a single token + let token_str = String::from_utf8_lossy(&token); + println!("Token added: {:?}", token_str); + vocab.insert(token, vocab.len() as u64); + + num_token_added += 1; + } + // print_vocab_bytes(&vocab); + vocab +} + +#[cfg(test)] + +mod tests { + use super::*; + + #[test] + fn test_tokenize() { + let text = "Your text data here"; + let tokens = tokenize(text); + assert_eq!(tokens, vec!["Your", " ", "text", " ", "data", " ", "here"]); + } + + #[test] + fn test_all() { + let text: &str = "\tYou hear £ £ £ here"; + + let tokens = tokenize(text); + println!("{:?}", tokens); + let tokenized_bytes = convert_to_tokenized_bytes(tokens); + println!("{:?}", tokenized_bytes); + + let vocab_size = 10; + let bpe_vocab = build_bpe_vocab(tokenized_bytes, vocab_size); + println!("{:?}", bpe_vocab); + // Output or use the encoded text + } +} diff --git a/src/tiktoken.rs b/src/tiktoken.rs new file mode 100644 index 0000000..3d93065 --- /dev/null +++ b/src/tiktoken.rs @@ -0,0 +1,608 @@ +// This check is new and seems buggy (possibly with PyO3 interaction) +#![allow(clippy::borrow_deref_ref)] + +use std::collections::HashSet; +use std::thread; + +use fancy_regex::Regex; +use pyo3::exceptions; +use pyo3::prelude::*; +use pyo3::types::{PyBytes, PyList, PyTuple}; +use pyo3::PyResult; +use rustc_hash::FxHashMap as HashMap; + +fn _byte_pair_merge( + piece: &[u8], + ranks: &HashMap, usize>, + f: impl Fn(std::ops::Range) -> T, +) -> Vec { + // This is a vector of (start, rank). + // The rank is of the byte pair starting at position start. + // The rank of the last item in the vector is not a valid value. + let mut parts: Vec<(usize, usize)> = (0..piece.len() + 1).map(|i| (i, usize::MAX)).collect(); + + let get_rank = { + #[inline(always)] + |parts: &Vec<(usize, usize)>, start_idx: usize, skip: usize| { + if (start_idx + skip + 2) < parts.len() { + ranks + .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0]) + .copied() + } else { + None + } + } + }; + + // We look up the ranks once in the beginning and iteratively update + // them during each merge, which reduces the number of rank lookups. + for i in 0..parts.len() - 2 { + match get_rank(&parts, i, 0) { + Some(rank) => { + // usize::MAX is a sentinel value and cannot be a valid rank + debug_assert!(rank != usize::MAX); + parts[i].1 = rank; + } + None => { + continue; + } + }; + } + + // If you have n parts and m merges, this does O(mn) work. + // We could do something with a heap and do O(m log n) work. + // It is important to consider that n is often small (<100), and as such + // the cache-locality benefits outweigh the algorithmic complexity downsides + // of the `parts` vector data structure above. + + // Note that we hash bytes, not token pairs. As long as we train BPE the way we + // currently do, this is equivalent. An easy way to break this would be to decouple + // merge priority from token index or to prevent specific token merges. + loop { + if parts.len() == 1 { + break; + } + + // usize::MAX is a sentinel rank value allowing us to + // take the min more quickly + let mut min_rank: (usize, usize) = (usize::MAX, 0); + for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() { + if rank < min_rank.0 { + min_rank = (rank, i); + } + } + + if min_rank.0 != usize::MAX { + let i = min_rank.1; + + // NOTE: We are about to remove parts[i + 1]. We do not do it + // yet because there are cache-locality benefits to updating + // parts[i] and parts[i-1] before removing, which could thrash + // the cache. Thus, we update the rank calculation by skipping over + // parts[i + 1], by invoking `get_rank!` with `skip = 1`. + parts[i].1 = get_rank(&parts, i, 1).unwrap_or(usize::MAX); + if i > 0 { + parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(usize::MAX); + } + + parts.remove(i + 1); + } else { + break; + } + } + let mut out: Vec = Vec::with_capacity(parts.len() - 1); + for i in 0..parts.len() - 1 { + out.push(f(parts[i].0..parts[i + 1].0)); + } + out +} + +pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap, usize>) -> Vec { + if piece.len() == 1 { + return vec![ranks[piece]]; + } + _byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]]) +} + +pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap, usize>) -> Vec<&'a [u8]> { + if piece.len() == 1 { + return vec![piece]; + } + _byte_pair_merge(piece, ranks, |p| &piece[p.start..p.end]) +} + +// Various performance notes: +// +// Regex +// ===== +// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy +// regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than +// the usual regex we use. +// +// However, given that we're using a regex parse-able by `regex`, there isn't much difference +// between using the `regex` crate and using the `fancy_regex` crate. +// +// There is an important interaction between threading, `regex` and `fancy_regex`. +// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on +// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain +// old `regex`, we don't hit this, because `find_iter` has a different code path. +// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md +// Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for +// each thread. +// +// Threading +// ========= +// I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL. +// So goodbye `rayon`! Let thread count etc be in control of our Python users. +// +// Caching +// ======= +// The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`. +// Originally, we had one too! Without it, we were only vaguely faster than Python. +// I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance +// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect +// multi-threaded performance even when I only had readers (maybed I messed something up?). +// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache! +// These are exactly the set or merges that are likely to be hot. And now we don't have to think +// about interior mutability, memory use, or cloning. +// +// Hashing +// ======= +// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win? +// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made +// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster. + +use std::num::NonZeroU64; +pub struct FakeThreadId(NonZeroU64); + +fn hash_current_thread() -> usize { + // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter + // that works great for our use case of avoiding collisions in our array. Unfortunately, + // it's private. However, there are only so many ways you can layout a u64, so just transmute + // https://github.com/rust-lang/rust/issues/67939 + const _: [u8; 8] = [0; std::mem::size_of::()]; + const _: [u8; 8] = [0; std::mem::size_of::()]; + let x = unsafe { + std::mem::transmute::(thread::current().id()).0 + }; + u64::from(x) as usize +} + +const MAX_NUM_THREADS: usize = 128; +#[pyclass] +struct CoreBPE { + encoder: HashMap, usize>, + special_tokens_encoder: HashMap, + decoder: HashMap>, + special_tokens_decoder: HashMap>, + regex_tls: Vec, + special_regex_tls: Vec, + sorted_token_bytes: Vec>, +} + +impl CoreBPE { + fn _get_tl_regex(&self) -> &Regex { + // See performance notes above for what this is about + // It's also a little janky, please make a better version of it! + // However, it's nice that this doesn't leak memory to short-lived threads + &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS] + } + + fn _get_tl_special_regex(&self) -> &Regex { + &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS] + } + + fn _decode_native(&self, tokens: &[usize]) -> Vec { + let mut ret = Vec::with_capacity(tokens.len() * 2); + for token in tokens { + let token_bytes = self + .decoder + .get(token) + .unwrap_or_else(|| &self.special_tokens_decoder[token]); + ret.extend(token_bytes); + } + ret + } + + fn _encode_ordinary_native(&self, text: &str) -> Vec { + // This is the core of the encoding logic; the other functions in here + // just make things complicated :-) + let regex = self._get_tl_regex(); + let mut ret = vec![]; + for mat in regex.find_iter(text) { + let piece = mat.unwrap().as_str().as_bytes(); + if let Some(token) = self.encoder.get(piece) { + ret.push(*token); + continue; + } + ret.extend(&byte_pair_encode(piece, &self.encoder)); + } + ret + } + + fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec, usize) { + let special_regex = self._get_tl_special_regex(); + let regex = self._get_tl_regex(); + let mut ret = vec![]; + + let mut start = 0; + let mut last_piece_token_len = 0; + loop { + let mut next_special; + let mut start_find = start; + loop { + // Find the next allowed special token, if any + next_special = special_regex.find_from_pos(text, start_find).unwrap(); + match next_special { + Some(m) => { + if allowed_special.contains(&text[m.start()..m.end()]) { + break; + } + start_find = m.start() + 1; + } + None => break, + } + } + let end = next_special.map_or(text.len(), |m| m.start()); + + // Okay, here we go, compare this logic to _encode_ordinary_native + for mat in regex.find_iter(&text[start..end]) { + let piece = mat.unwrap().as_str().as_bytes(); + if let Some(token) = self.encoder.get(piece) { + last_piece_token_len = 1; + ret.push(*token); + continue; + } + let tokens = byte_pair_encode(piece, &self.encoder); + last_piece_token_len = tokens.len(); + ret.extend(&tokens); + } + + match next_special { + // And here we push the special token + Some(m) => { + let piece = m.as_str(); + let token = self.special_tokens_encoder[piece]; + ret.push(token); + start = m.end(); + last_piece_token_len = 0; + } + None => break, + } + } + + // last_piece_token_len is how many tokens came from the last regex split. This is used + // for determining unstable tokens, since you can't merge across (stable) regex splits + (ret, last_piece_token_len) + } + + fn _increase_last_piece_token_len( + &self, + tokens: Vec, + mut last_piece_token_len: usize, + ) -> (Vec, usize) { + // Unfortunately, the locations where our regex splits can be unstable. + // For the purposes of determining unstable tokens, unstable regex splitting + // is only a problem if a split that was present disappears, since this can + // lead to merging of tokens otherwise thought to be stable. + // cl100k_base makes our life hard by including the \s*[\r\n]+ + // pattern. This can e.g. cause "\n" + " " to become "\n \n". + // Here is a quick and dirty fix: + { + let token_is_all_space = |token| { + self.decoder + .get(token) + .map(|token_bytes| { + token_bytes + .iter() + .rev() + .all(|&b| [b' ', b'\n', b'\t'].contains(&b)) + }) + .unwrap_or(false) + }; + if last_piece_token_len > 0 + && token_is_all_space(&tokens[tokens.len() - last_piece_token_len]) + { + while (last_piece_token_len < tokens.len()) + && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1]) + { + last_piece_token_len += 1; + } + } + } + debug_assert!(last_piece_token_len <= tokens.len()); + + (tokens, last_piece_token_len) + } + + fn _encode_unstable_native( + &self, + text: &str, + allowed_special: &HashSet<&str>, + ) -> (Vec, HashSet>) { + let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special); + if last_piece_token_len == 0 { + // If last_piece_token_len is zero, the last token was a special token and we have + // no unstable bytes + return (tokens, HashSet::new()); + } + let (mut tokens, last_piece_token_len) = + self._increase_last_piece_token_len(tokens, last_piece_token_len); + + let unstable_bytes = self._decode_native(&tokens[tokens.len() - last_piece_token_len..]); + tokens.truncate(tokens.len() - last_piece_token_len); + + // TODO: we should try harder to find additional stable tokens + // This would reduce the amount of retokenising when determining completions + // Refer to the logic in an older version of this file + + let mut completions = HashSet::new(); + if unstable_bytes.is_empty() { + return (tokens, completions); + } + + // This is the easy bit. Just find all single tokens that start with unstable_bytes + // (including tokens that exactly match unstable_bytes) + // Separating this from the loop below helps with performance in a common case. + let mut point = self + .sorted_token_bytes + .partition_point(|x| x.as_slice() < unstable_bytes.as_slice()); + while point < self.sorted_token_bytes.len() + && self.sorted_token_bytes[point].starts_with(&unstable_bytes) + { + completions.insert(vec![ + self.encoder[self.sorted_token_bytes[point].as_slice()], + ]); + point += 1; + } + + // Now apply even more brute force. At every (other) possible position for the straddling + // token, concatenate additional bytes from that token (if any) to unstable_bytes, + // and retokenise the whole thing and see what we get. + for i in 1..unstable_bytes.len() { + let prefix = &unstable_bytes[..i]; + let suffix = &unstable_bytes[i..]; + let mut point = self + .sorted_token_bytes + .partition_point(|x| x.as_slice() < suffix); + // TODO: Perf optimisation if suffix starts with " "? + while point < self.sorted_token_bytes.len() + && self.sorted_token_bytes[point].starts_with(suffix) + { + let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat(); + let encoded = match std::str::from_utf8(&possibility) { + // Morally, this is byte_pair_encode(&possibility, &self.encoder) + // But we might have introduced a regex split which would prevent merges. + // (particularly possible in the presence of unstable regex splits) + // So convert to UTF-8 and do regex splitting. + // E.g. with cl100k_base " !" gets split to " " + " !", + // but byte_pair_encode(" !") != byte_pair_encode(" ") + Ok(s) => self._encode_ordinary_native(s), + + // Technically, whether or not this arm is correct depends on whether there + // would be a regex split before the UTF-8 truncation point. + // Probably niche enough that no one will ever notice (after all, people didn't + // notice all the big holes in the previous unstable token implementation) + Err(_) => byte_pair_encode(&possibility, &self.encoder), + // Something like the following is intriguing but incorrect: + // Err(e) => self._encode_ordinary_native(unsafe { + // std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()]) + // }), + }; + let mut seq = Vec::new(); + let mut seq_len = 0; + for token in encoded { + seq.push(token); + seq_len += self.decoder[&token].len(); + if seq_len >= unstable_bytes.len() { + break; + } + } + completions.insert(seq); + point += 1; + } + } + + // This is also not straightforward. While we generally assume that regex splits are stable, + // unfortunately, they are not. That is, if adding bytes were to make a split appear in + // unstable_bytes, this could make tokens possible which our logic would otherwise think + // would be merged. + // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could + // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token. + // Here is a quick and dirty fix: + // This isn't right if we ever remove \s+(?!\S) + if unstable_bytes.len() > 1 { + let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice()); + if unstable_bytes.len() - last_decoded.1 > 0 + && last_decoded.0.map_or(false, |c| c.is_whitespace()) + { + let mut reencoded = byte_pair_encode( + &unstable_bytes[..unstable_bytes.len() - last_decoded.1], + &self.encoder, + ); + reencoded.extend(byte_pair_encode( + &unstable_bytes[unstable_bytes.len() - last_decoded.1..], + &self.encoder, + )); + completions.insert(reencoded); + } + } + + (tokens, completions) + } +} + +#[pymethods] +impl CoreBPE { + #[new] + fn new( + encoder: HashMap, usize>, + special_tokens_encoder: HashMap, + pattern: &str, + ) -> PyResult { + let regex = Regex::new(pattern) + .map_err(|e| PyErr::new::(e.to_string()))?; + + let special_regex = { + let _parts = special_tokens_encoder + .keys() + .map(|s| fancy_regex::escape(s)) + .collect::>(); + Regex::new(&_parts.join("|")) + .map_err(|e| PyErr::new::(e.to_string()))? + }; + + let decoder: HashMap> = + encoder.iter().map(|(k, v)| (*v, k.clone())).collect(); + + assert!( + encoder.len() == decoder.len(), + "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?" + ); + + let special_tokens_decoder: HashMap> = special_tokens_encoder + .iter() + .map(|(k, v)| (*v, k.as_bytes().to_vec())) + .collect(); + + // Clone because I don't know how to tell Rust I'm not going to change the map + let mut sorted_token_bytes: Vec> = encoder.keys().cloned().collect(); + sorted_token_bytes.sort(); + + Ok(CoreBPE { + encoder, + special_tokens_encoder, + decoder, + special_tokens_decoder, + regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(), + special_regex_tls: (0..MAX_NUM_THREADS) + .map(|_| special_regex.clone()) + .collect(), + sorted_token_bytes, + }) + } + + // ==================== + // Encoding + // ==================== + + fn encode_ordinary(&self, py: Python, text: &str) -> Vec { + py.allow_threads(|| self._encode_ordinary_native(text)) + } + + fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec { + py.allow_threads(|| self._encode_native(text, &allowed_special).0) + } + + fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec { + py.allow_threads(|| { + match std::str::from_utf8(bytes) { + Ok(text) => self._encode_ordinary_native(text), + Err(e) => { + let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) }; + let (tokens, last_piece_token_len) = self._encode_native(text, &HashSet::new()); + let (mut tokens, last_piece_token_len) = + self._increase_last_piece_token_len(tokens, last_piece_token_len); + if !tokens.is_empty() && last_piece_token_len > 0 { + // Lop off the tokens from the last piece and run BPE on the remaining bytes + // Somewhat niche, but this may not be correct if we'd have had a regex + // split between the valid UTF-8 and the invalid bytes, which is why this + // method is private + let mut unstable_bytes = + self._decode_native(&tokens[tokens.len() - last_piece_token_len..]); + unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]); + + tokens.truncate(tokens.len() - last_piece_token_len); + tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder)); + } + tokens + } + } + }) + } + + fn encode_with_unstable( + &self, + py: Python, + text: &str, + allowed_special: HashSet<&str>, + ) -> Py { + let (tokens, completions) = + py.allow_threads(|| self._encode_unstable_native(text, &allowed_special)); + let py_completions = + PyList::new(py, completions.iter().map(|seq| PyList::new(py, &seq[..]))); + (tokens, py_completions).into_py(py) + } + + fn encode_single_token(&self, piece: &[u8]) -> PyResult { + if let Some(token) = self.encoder.get(piece).copied() { + return Ok(token); + } + if let Ok(piece_str) = std::str::from_utf8(piece) { + if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() { + return Ok(token); + } + } + Err(PyErr::new::(piece.to_owned())) + } + + fn encode_single_piece(&self, piece: &[u8]) -> Vec { + if let Some(token) = self.encoder.get(piece) { + return vec![*token]; + } + byte_pair_encode(piece, &self.encoder) + } + + // ==================== + // Decoding + // ==================== + + fn decode_bytes(&self, py: Python, tokens: Vec) -> Py { + let bytes = py.allow_threads(|| self._decode_native(&tokens)); + PyBytes::new(py, &bytes).into() + } + + fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult> { + if let Some(bytes) = self.decoder.get(&token) { + return Ok(PyBytes::new(py, bytes).into()); + } + if let Some(bytes) = self.special_tokens_decoder.get(&token) { + return Ok(PyBytes::new(py, bytes).into()); + } + Err(PyErr::new::(token.to_string())) + } + + // ==================== + // Miscellaneous + // ==================== + + fn token_byte_values(&self, py: Python) -> Vec> { + self.sorted_token_bytes + .iter() + .map(|x| PyBytes::new(py, x).into()) + .collect() + } +} + +#[pymodule] +fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use rustc_hash::FxHashMap as HashMap; + + use crate::byte_pair_split; + + #[test] + fn very_simple_test() { + let mut ranks = HashMap::default(); + ranks.insert(b"ab".to_vec(), 1); + ranks.insert(b"cd".to_vec(), 2); + + let res = byte_pair_split(b"abcd", &ranks); + assert_eq!(res, vec![b"ab", b"cd"]); + } +}