From ac46243c2cfe0920253e8446d22e57899fde3188 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 3 Jan 2025 14:21:41 +0100 Subject: [PATCH 01/11] initial-commit --- tokenizers/Cargo.toml | 2 + .../src/models/backtracking_bpe/bitfield.rs | 57 ++ tokenizers/src/models/backtracking_bpe/mod.rs | 2 + .../src/models/backtracking_bpe/model.rs | 816 ++++++++++++++++++ tokenizers/src/models/mod.rs | 1 + 5 files changed, 878 insertions(+) create mode 100644 tokenizers/src/models/backtracking_bpe/bitfield.rs create mode 100644 tokenizers/src/models/backtracking_bpe/mod.rs create mode 100644 tokenizers/src/models/backtracking_bpe/model.rs diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 16285e113..ddbeba4d1 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -68,6 +68,8 @@ fancy-regex = { version = "0.13", optional = true} getrandom = { version = "0.2.10" } esaxx-rs = { version = "0.1.10", default-features = false, features=[]} monostate = "0.1.12" +fnv = "1.0.7" +aneubeck-daachorse = "1.1.1" [features] default = ["progressbar", "onig", "esaxx_fast"] diff --git a/tokenizers/src/models/backtracking_bpe/bitfield.rs b/tokenizers/src/models/backtracking_bpe/bitfield.rs new file mode 100644 index 000000000..832965931 --- /dev/null +++ b/tokenizers/src/models/backtracking_bpe/bitfield.rs @@ -0,0 +1,57 @@ +/// Small helper to manage a bit field which supports predecessor and successor queries with a simple scan implementation. +/// This is sufficient for our use case, since two one bits will be at most 128 bits apart. +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct BitField { + bitfield: Vec, +} + +impl BitField { + /// All bits are initialized to 1. + pub(crate) fn new(bits: usize) -> Self { + Self { + bitfield: vec![u64::MAX; (bits + 63) / 64], + } + } + + pub(crate) fn is_set(&self, bit: usize) -> bool { + let (word, bit) = (bit / 64, bit % 64); + self.bitfield[word] & (1 << bit) != 0 + } + + pub(crate) fn clear(&mut self, bit: usize) { + let (word, bit) = (bit / 64, bit % 64); + self.bitfield[word] &= !(1 << bit); + } + + pub(crate) fn successor(&self, bit: usize) -> usize { + let (mut word_idx, bit_idx) = (bit / 64, bit % 64); + let word = self.bitfield[word_idx] >> bit_idx; + if word != 0 { + word.trailing_zeros() as usize + bit + } else { + loop { + word_idx += 1; + let word = self.bitfield[word_idx]; + if word != 0 { + break word.trailing_zeros() as usize + word_idx * 64; + } + } + } + } + + pub(crate) fn predecessor(&self, bit: usize) -> usize { + let (mut word_idx, bit_idx) = (bit / 64, bit % 64); + let word = self.bitfield[word_idx] << (63 - bit_idx); + if word != 0 { + bit - word.leading_zeros() as usize + } else { + loop { + word_idx -= 1; + let word = self.bitfield[word_idx]; + if word != 0 { + break word_idx * 64 + 63 - word.leading_zeros() as usize; + } + } + } + } +} diff --git a/tokenizers/src/models/backtracking_bpe/mod.rs b/tokenizers/src/models/backtracking_bpe/mod.rs new file mode 100644 index 000000000..2e52b478d --- /dev/null +++ b/tokenizers/src/models/backtracking_bpe/mod.rs @@ -0,0 +1,2 @@ +mod bitfield; +mod model; diff --git a/tokenizers/src/models/backtracking_bpe/model.rs b/tokenizers/src/models/backtracking_bpe/model.rs new file mode 100644 index 000000000..1afd57dfa --- /dev/null +++ b/tokenizers/src/models/backtracking_bpe/model.rs @@ -0,0 +1,816 @@ +use super::bitfield::BitField; +use super::{super::bpe::trainer::BpeTrainer, super::bpe::Error, super::OrderedVocabIter}; +use crate::models::bpe::BPE; +use crate::tokenizer::{Model, Result, Token}; +use crate::utils::iter::ResultShunt; +use aneubeck_daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder}; +use fnv::{FnvHashMap, FnvHasher}; +use itertools::Itertools; +use serde::de::Visitor; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde_json::Value; +use std::cmp::Reverse; +use std::collections::BinaryHeap; +use std::hash::{Hash, Hasher}; +use std::ops::Range; +use std::{ + collections::HashMap, + fs::File, + io::prelude::*, + io::{BufRead, BufReader}, + path::{Path, PathBuf}, +}; +pub type Vocab = HashMap; +type VocabR = HashMap; +pub type Merges = Vec<(String, String)>; + +/// This can be thought of as a lazy variation of the dynamic programming approach. +/// It only computes those states which have to be visited in order to compute the tokenization +/// for a given input text. +/// It keeps track of visited states in a bitfield and only remembers the tokenization +/// of the currently processed dynamic programming state. +/// +/// The biggest downside of this approach is that the search for the longest leftmost match (the firt token?) +/// has to be reset at every (backtracking) step which is still a net win in practice compared to other approaches. +#[derive(Clone, PartialEq)] +pub(crate) struct BacktrackState<'a> { + text: &'a [u8], + tokens: Vec, // len of the tezt / 3 + next_token: Option, // bpe.next_match(text) wich is longest_searcher.leftmost_find_iter(text)'s first match value + pos: usize, // current pos in the text? + bitfield: BitField, // keeps track of token boundaries? keeps track of all the valid tokenization positions and making the runtime linear in the input length. +} + +impl<'a> BacktrackState<'a> { + pub(crate) fn new(text: &'a [u8], next_token: Option) -> Self { + Self::with_capacity(text, next_token, text.len() / 3) + } + + pub(crate) fn with_capacity(text: &'a [u8], next_token: Option, cap: usize) -> Self { + Self { + text, + tokens: Vec::with_capacity(cap), + next_token, + pos: 0, + bitfield: BitField::new(text.len() + 1), + } + } + pub(crate) fn count(&self) -> usize { + self.tokens.len() + } + + pub(crate) fn pos(&self) -> usize { + self.pos + } + + pub(crate) fn last_token(&self) -> Option { + self.tokens.last().copied() + } + + pub(crate) fn into_tokens(self) -> Vec { + self.tokens + } +} + +struct Config { + files: Option<(String, String)>, + vocab: Vocab, + merges: Merges, + dropout: Option, + unk_token: Option, + fuse_unk: bool, + byte_fallback: bool, +} + +pub struct BacktrackingBpeBuilder { + config: Config, +} + +impl Default for BacktrackingBpeBuilder { + fn default() -> Self { + Self { + config: Config { + files: None, + vocab: HashMap::new(), + merges: vec![], + dropout: None, + unk_token: None, + fuse_unk: false, + byte_fallback: false, + }, + } + } +} + +/// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model. +#[derive(Serialize, PartialEq)] +pub struct BacktrackingBpe { + /// All the decoded tokens concatenated into? used to build the aho corasick searchers + all_tokens: Vec, + /// Start index of each token in all_tokens. + /// The end is simply the next entry in this vector. + token_starts: Vec, + /// Mapping from hash of token to token id. + bytes_hash_to_token: FnvHashMap, + /// The two tokens from which the token got merged. + /// If the token is an original one, than the two tokens point back to itself. + split_table: Vec<(u32, u32)>, + /// Mapping from a pair of tokens to a merged token if such a merged token exists. + pair_lookup: FnvHashMap<(u32, u32), u32>, + /// An aho corasick automaton to find the next longest token in a byte sequence. + // #[serde( + // serialize_with = "serialize_daac", + // deserialize_with = "deserialize_daac" + // )] + #[serde(skip)] + longest_searcher: DoubleArrayAhoCorasick, + /// An aho corasick automaton to find ALL tokens in a byte sequence. + // #[serde( + // serialize_with = "serialize_daac", + // deserialize_with = "deserialize_daac" + // )] + #[serde(skip)] + pub(crate) overlapping_searcher: DoubleArrayAhoCorasick, + /// An aho corasick automaton to find ALL tokens in a byte sequence which is being processed in reverse order. + // #[serde( + // serialize_with = "serialize_daac", + // deserialize_with = "deserialize_daac" + // )] + #[serde(skip)] + pub(crate) overlapping_searcher_rev: DoubleArrayAhoCorasick, + /// Mapping from a token to the next longest prefix token. + /// This is in principle information represented by the AhoCorasick automaton. + /// But we don't have efficient access to it and therefore store it here again. + /// If there is none, then the value is set to u32::MAX. + next_prefix_match: Vec, + /// Hash factor used to prevent hash collisions. + hash_factor: u64, + vocab: Vocab, + vocab_r: VocabR, + unk_token: Option, +} + +impl BacktrackingBpeBuilder { + /// Constructs a new `BacktrackingBpeBuilder`. + pub fn new() -> Self { + Self::default() + } + + /// Set the input files. + #[must_use] + pub fn files(mut self, vocab: String, merges: String) -> Self { + self.config.files = Some((vocab, merges)); + self + } + + /// Set the vocab (token -> ID) and merges mappings. + #[must_use] + pub fn vocab_and_merges(mut self, vocab: Vocab, merges: Merges) -> Self { + self.config.vocab = vocab; + self.config.merges = merges; + self + } + + /// Use [dropout](https://arxiv.org/abs/1910.13267) with the model. + #[must_use] + pub fn dropout(mut self, dropout: f32) -> Self { + self.config.dropout = Some(dropout); + self + } + + /// Set the `UNK` token for the vocab. + #[must_use] + pub fn unk_token(mut self, unk_token: String) -> Self { + self.config.unk_token = Some(unk_token); + self + } + + /// Set the `fuse_unk` option. + #[must_use] + pub fn fuse_unk(mut self, fuse_unk: bool) -> Self { + self.config.fuse_unk = fuse_unk; + self + } + + /// Set the `byte_fallback` option. + #[must_use] + pub fn byte_fallback(mut self, byte_fallback: bool) -> Self { + self.config.byte_fallback = byte_fallback; + self + } + + /// Returns a `BacktrackingBpe` model that uses the `BacktrackingBpeBuilder`'s configuration. + pub fn build(mut self) -> Result { + // Validate dropout. + if let Some(p) = self.config.dropout { + if !(0.0..=1.0).contains(&p) { + return Err(Error::InvalidDropout.into()); + } + } + + // Read files if necessary + if let Some((vocab, merges)) = self.config.files { + let (v, m) = BPE::read_file(&vocab, &merges)?; + self.config.vocab = v; + self.config.merges = m; + } + + let backtraching_bpe = BacktrackingBpe::from_dictionary( + self.config.vocab.into_iter().map(|(k, v)| k.into_bytes()), + Some(self.config.merges), + None, + ); + Ok(backtraching_bpe) + } +} + +impl Default for BacktrackingBpe { + fn default() -> Self { + Self::builder().build().unwrap() + } +} + +fn token_iter<'a>(all_tokens: &'a [u8], token_starts: &'a [u32]) -> impl Iterator { + token_starts + .iter() + .tuple_windows() + .map( move |(start, end)| &all_tokens[*start as usize..*end as usize]) +} + +fn next_match(longest_searcher: &DoubleArrayAhoCorasick, text: &[u8]) -> Option { + longest_searcher + .leftmost_find_iter(text) + .map(|m| m.value()) + .next() +} + +fn is_valid_token_pair( + pair_lookup: &FnvHashMap<(u32, u32), u32>, + split_table: &[(u32, u32)], + mut token1: u32, + mut token2: u32, +) -> bool { + // Keep track of the maximum token which can still be chosen across the split point. + let mut limit = u32::MAX; + loop { + // Check whether BPE would choose a different token pair across the split point. + if let Some(combined) = pair_lookup.get(&(token1, token2)) { + if *combined < limit { + return false; + } + } + // Reverse the merge operation from BPE. + if token1 > token2 { + limit = token1; + token1 = unsafe { split_table.get_unchecked(token1 as usize).1 }; + if token1 == limit { + limit = token2 + 1; + token2 = unsafe { split_table.get_unchecked(token2 as usize).0 }; + if token2 + 1 == limit { + return true; + } + } + } else { + limit = token2 + 1; + token2 = unsafe { split_table.get_unchecked(token2 as usize).0 }; + if token2 + 1 == limit { + limit = token1; + token1 = unsafe { split_table.get_unchecked(token1 as usize).1 }; + if token1 == limit { + return true; + } + } + } + } +} + +fn token_range(token_starts: &[u32], token_id: u32) -> Range { + unsafe { + *token_starts.get_unchecked(token_id as usize) as usize + ..*token_starts.get_unchecked(token_id as usize + 1) as usize + } +} + +fn token_bytes<'a>(all_tokens: &'a [u8], token_starts: &[u32], token_id: u32) -> &'a [u8] { + &all_tokens[token_range(token_starts, token_id)] +} + +fn hash_bytes(bytes: &[u8], factor: u64) -> u32 { + let mut hasher = FnvHasher::default(); + bytes.hash(&mut hasher); + // Note: we save 1/3 of space for the hashmap by only using the most significant bits of the hash. + // To make them unique for the given tokens, we have to add unfortunately another multiplication. + ((hasher.finish().wrapping_mul(factor)) >> 32) as u32 +} +fn find_token_by_bytes( + all_tokens: &[u8], + token_starts: &[u32], + bytes_hash_to_token: &FnvHashMap, + bytes: &[u8], + hash_factor: u64, +) -> Option { + let hash = hash_bytes(bytes, hash_factor); + let token = *bytes_hash_to_token.get(&hash)?; + if token_bytes(all_tokens, token_starts, token) == bytes { + Some(token) + } else { + None + } +} + +/// Converts the merges strings (for example from `merges.txt` file) with the format +/// "{pair_a} {pair_b}" into the format expected by the BacktrackingBpe struct +pub(crate) fn convert_merges_to_hashmap>( + iter: I, + _vocab: &Vocab, +) -> Result { + let mut merges = vec![]; + + let lines = iter.filter(|l| !l.starts_with("#version")); + for (rank, line) in lines.enumerate() { + let parts = line.split(' ').collect::>(); + if parts.len() != 2 { + return Err(Error::BadMerges(rank + 1).into()); + } + + merges.push((parts[0].to_string(), parts[1].to_string())); + } + + Ok(merges) +} + +impl BacktrackingBpe { + /// Initialize a `BacktrackingBpeBuilder`. + pub fn builder() -> BacktrackingBpeBuilder { + BacktrackingBpeBuilder::new() + } + + /// Create a new BacktrackingBpe model with the given vocab and merges. + pub fn new(vocab: Vocab, merges: Merges) -> Self { + Self::builder() + .vocab_and_merges(vocab, merges) + .build() + .unwrap() + } + + fn bitfield_into_tokens(&self, bytes: &[u8], bitfield: BitField, count: usize) -> Vec { + let mut encoded = Vec::with_capacity(count); + let mut start = 0; + while start < bytes.len() { + let end = bitfield.successor(start + 1); + let token = self.find_token_by_bytes(&bytes[start..end]).expect(""); + encoded.push(token); + start = end; + } + encoded + } + + fn encode_into_bitfield(&self, bytes: &[u8]) -> (BitField, usize) { + // Reserve for every byte a bit in the bitfield. + let mut bitfield = BitField::new(bytes.len() + 1); + let mut heap = BinaryHeap::with_capacity(bytes.len() * 2); + heap.extend((0..bytes.len().saturating_sub(1)).filter_map(|i| { + self.find_token_by_bytes(&bytes[i..i + 2]) + .map(|e| Reverse((e, i as u32))) + })); + let mut count = bytes.len(); + while let Some(Reverse((token, start))) = heap.pop() { + let start = start as usize; + if !bitfield.is_set(start) { + continue; + } + let mid = bitfield.successor(start + 1); + if mid >= bytes.len() { + continue; + } + let end = bitfield.successor(mid + 1); + if self.token_len(token) != end - start { + continue; + } + bitfield.clear(mid); + count -= 1; + if end < bytes.len() { + let new_end = bitfield.successor(end + 1); + if let Some(e) = self.find_token_by_bytes(&bytes[start..new_end]) { + heap.push(Reverse((e, start as u32))); + } + } + if start > 0 { + let new_start = bitfield.predecessor(start - 1); + if let Some(e) = self.find_token_by_bytes(&bytes[new_start..end]) { + heap.push(Reverse((e, new_start as u32))); + } + } + } + (bitfield, count) + } + + pub fn encode_via_bitfield(&self, text: &[u8]) -> Vec { + let (bitfield, count) = self.encode_into_bitfield(text); + self.bitfield_into_tokens(text, bitfield, count) + } + + /// Construct a BytePairEncoding instance from an iterator that enumerates all tokens. + /// A suitable hash factor may be necessary to prevent hash collisions, which can be + /// found using [`find_hash_factor_for_dictionary`]. + /// + /// The recommended approach is to store the serialized value and reuse that, + /// to prevent repeating the cost of computing the hash factor and encoding. + pub fn from_dictionary( + tokens: impl IntoIterator>, + merges: Option>, + hash_factor: Option, + ) -> Self { + let hash_factor = hash_factor + .inspect(|f| assert_ne!(*f, 0, "hash factor must be larger than zero")) + .unwrap_or(1); + let mut all_tokens = Vec::new(); + let mut all_tokens_rev = Vec::new(); + let mut token_starts = vec![0]; + let mut bytes_hash_to_token = FnvHashMap::default(); + for (i, token) in tokens.into_iter().enumerate() { + bytes_hash_to_token.insert(hash_bytes(&token, hash_factor), i as u32); + all_tokens_rev.extend(token.iter().copied().rev()); + all_tokens.extend(token); + token_starts.push(all_tokens.len() as u32); + } + assert_eq!(bytes_hash_to_token.len() + 1, token_starts.len()); + let longest_searcher = DoubleArrayAhoCorasickBuilder::new() + .match_kind(aneubeck_daachorse::MatchKind::LeftmostLongest) + .build(token_iter(&all_tokens, &token_starts)) + .expect("failed to build AhoCorasick"); + + let overlapping_searcher = + DoubleArrayAhoCorasick::::new(token_iter(&all_tokens, &token_starts)).expect(""); + let overlapping_searcher_rev = + DoubleArrayAhoCorasick::::new(token_iter(&all_tokens, &token_starts)).expect(""); + + let next_prefix_match: Vec<_> = token_iter(&all_tokens, &token_starts) + .map(|token| { + next_match(&longest_searcher, &token[0..token.len() - 1]).unwrap_or(u32::MAX) + }) + .collect(); + + let mut split_table = vec![]; + let mut pair_lookup = FnvHashMap::default(); + + // Reverse engineer the merge/split table. + for (id, token) in token_iter(&all_tokens, &token_starts).enumerate() { + let mut token1 = next_prefix_match[id]; + while token1 != u32::MAX { + let rest = &token[token_range(&token_starts, token1).len()..]; + if let Some(token2) = find_token_by_bytes( + &all_tokens, + &token_starts, + &bytes_hash_to_token, + rest, + hash_factor, + ) { + if token1 < id as u32 + && token2 < id as u32 + && is_valid_token_pair(&pair_lookup, &split_table, token1, token2) + { + pair_lookup.insert((token1, token2), id as u32); + split_table.push((token1, token2)); + break; + } + } + token1 = next_prefix_match[token1 as usize]; + } + if token1 == u32::MAX { + split_table.push((id as u32, id as u32)); + } + } + let vocab: HashMap = token_iter(&all_tokens, &token_starts) + .enumerate() + .map(|(id, item)| (unsafe { String::from_utf8_unchecked(Vec::from(item.clone())) }, id as u32)) + .collect(); + + let vocab_r: HashMap = token_iter(&all_tokens, &token_starts) + .enumerate() + .map(|(id, item)| (id as u32, unsafe { String::from_utf8_unchecked(Vec::from(item.clone())) })) + .collect(); + + let bpe = Self { + all_tokens, + token_starts, + bytes_hash_to_token, + overlapping_searcher, + overlapping_searcher_rev, + longest_searcher, + next_prefix_match, + pair_lookup, + split_table, + hash_factor, + unk_token: None, + vocab, + vocab_r, + }; + for token_id in 0..bpe.num_tokens() as u32 { + let bytes = bpe.token_bytes(token_id); + let tokens = bpe.encode_via_bitfield(bytes); + assert_eq!( + tokens, + vec![token_id], + "token {token_id} with bytes {bytes:?} encodes to {tokens:?} instead of to itself" + ); + } + bpe + } + + /// Initialize a BacktrackingBpeBuilder model from vocab and merges files + pub fn from_file(vocab: &str, merges: &str) -> BacktrackingBpeBuilder { + Self::builder().files(vocab.to_owned(), merges.to_owned()) + } + + /// Read the given files to extract the vocab and merges + pub fn read_file(vocab: &str, merges: &str) -> Result<(Vocab, Merges)> { + // Read vocab.json + let vocab_file = File::open(vocab)?; + let mut vocab_file = BufReader::new(vocab_file); + + let mut buffer = String::new(); + vocab_file.read_to_string(&mut buffer)?; + let json: Value = serde_json::from_str(&buffer)?; + let mut vocab = HashMap::new(); + match json { + Value::Object(m) => { + for (token, id) in m { + if let Value::Number(id) = id { + let id = id.as_u64().ok_or(Error::BadVocabulary)? as u32; + vocab.insert(token, id); + } + } + } + _ => return Err(Box::new(Error::BadVocabulary)), + }; + + // Read merges file + let merge_file = File::open(merges)?; + let merge_file = BufReader::new(merge_file); + let merges = ResultShunt::process(merge_file.lines(), |iter| { + convert_merges_to_hashmap(iter, &vocab) + })??; // TODO correctly process to fill the split and pair lookup + + Ok((vocab, merges)) + } + + /// Return the number of tokens in this BPE dictionary. + pub fn num_tokens(&self) -> usize { + self.token_starts.len() - 1 + } + + /// Converts a token id into its corresponding token bytes. + /// Panics if the token_id is not within the valid 0..num_tokens() range! + pub fn token_bytes(&self, token_id: u32) -> &[u8] { + token_bytes(&self.all_tokens, &self.token_starts, token_id) + } + + pub(crate) fn is_valid_token_pair(&self, token1: u32, token2: u32) -> bool { + is_valid_token_pair(&self.pair_lookup, &self.split_table, token1, token2) + } + + /// Returns the length of the decoded byte slice of a token. + pub fn token_len(&self, token_id: u32) -> usize { + token_range(&self.token_starts, token_id).len() + } + + /// Returns the first longest match in the provided text. + pub(crate) fn next_match(&self, text: &[u8]) -> Option { + next_match(&self.longest_searcher, text) + } + + /// Returns the next token which shares the longest prefix with the specified token. + pub(crate) fn next_prefix(&self, token_id: u32) -> Option { + let prefix = self.next_prefix_match[token_id as usize]; + if prefix == u32::MAX { + None + } else { + Some(prefix) + } + } + + fn find_token_by_bytes(&self, bytes: &[u8]) -> Option { + find_token_by_bytes( + &self.all_tokens, + &self.token_starts, + &self.bytes_hash_to_token, + bytes, + self.hash_factor, + ) + } + + /// Decode a sequence of tokens back to its original byte sequence. + /// Note: we don't return here a str, since not every token sequence corresponds to a valid + /// utf8 sequence. + pub fn decode_tokens(&self, tokens: &[u32]) -> Vec { + let mut text = vec![]; + for token in tokens { + text.extend(self.token_bytes(*token)); + } + text + } + + /// Computes for every prefix of the input text a corresponding last token. + pub(crate) fn encode_all_prefixes(&self, text: &[u8]) -> Vec { + let mut last_token = Vec::with_capacity(text.len()); + let mut state = self.overlapping_searcher.start_state(); + for (pos, c) in text.iter().enumerate() { + let (s, iter) = self.overlapping_searcher.consume(state, pos + 1, *c); + state = s; + for m in iter { + let new_token = m.value(); + let new_range = m.start()..m.end(); + assert_eq!(new_range.end, last_token.len() + 1); + if new_range.start == 0 { + last_token.push(new_token); + break; + } else { + let prev_token = unsafe { *last_token.get_unchecked(new_range.start - 1) }; + if self.is_valid_token_pair(prev_token, new_token) { + last_token.push(new_token); + break; + } + } + } + } + last_token + } + + /// Counts the number tokens produced when encoding the text. + pub fn count(&mut self, text: &[u8]) -> usize { + let mut enc = BacktrackState::new(text, None); + while self.step(&mut enc).is_some() {} + enc.count() + } + + pub fn encode_via_table(&self, text: &[u8]) -> Vec { + let last_token = self.encode_all_prefixes(text); + let mut encoded = Vec::with_capacity(text.len() / 3); + let mut pos = text.len(); + while pos > 0 { + let token = last_token[pos - 1]; + encoded.push(token); + pos -= self.token_len(token); + } + encoded.reverse(); + encoded + } + + pub fn encode_via_backtracking(&self, text: &[u8]) -> Vec { + let mut enc = BacktrackState::new(text, None); + while self.step(&mut enc).is_some() {} + println!("_______________________________"); + enc.into_tokens() + } + + pub fn get_vocab(&self) -> Vocab { + self.vocab.clone() + } + + pub fn get_unk_token(&self) -> &Option { + &self.unk_token + } + + pub fn step(&self, backtrack_state: &mut BacktrackState) -> Option { + let mut token = backtrack_state.next_token?; + let last = backtrack_state.tokens.last().copied(); + loop { + let token_len = self.token_len(token); + let end_pos = backtrack_state.pos + token_len; + if backtrack_state.bitfield.is_set(end_pos) + && last + .map(|last_token| self.is_valid_token_pair(last_token, token)) + .unwrap_or(true) + { + backtrack_state.tokens.push(token); + backtrack_state.pos = end_pos; + // In principle, we could in some cases reuse the leftmost longest match iterator. + // Especially when it has to look ahead, this could save scanning the input multiple times. + // But on average this seems to be slower due to the overhead of storing the iterator as part of the struct. + backtrack_state.next_token = self.next_match(&backtrack_state.text[end_pos..]); + break; + } else if let Some(shorter) = self.next_prefix(token) { + token = shorter; + } else { + // Clearing the bitfield when we pop tokens saves a little bit of work... + backtrack_state.bitfield.clear(backtrack_state.pos); + backtrack_state.tokens.pop(); + backtrack_state.pos -= last.map(|t| self.token_len(t)).unwrap_or(0); + backtrack_state.next_token = last; + break; + } + } + backtrack_state.next_token + } + + fn word_to_tokens<'a, 'b: 'a>( + &'a self, + word: &'b Vec, + ) -> impl Iterator + 'a { + word.into_iter() + .map(move |id| Token::new(*id, self.vocab_r[&id].clone(), (0usize, 0usize))) + } +} +impl Model for BacktrackingBpe { + type Trainer = BpeTrainer; + + fn get_vocab(&self) -> HashMap { + self.vocab.clone() + } + + fn get_vocab_size(&self) -> usize { + self.vocab.len() + } + + fn tokenize(&self, sequence: &str) -> Result> { + if sequence.is_empty() { + return Ok(vec![]); + } + let byte_text = sequence.as_bytes(); + let word = self.encode_via_backtracking(byte_text); + Ok(self.word_to_tokens(&word).collect()) + } + + fn token_to_id(&self, token: &str) -> Option { + self.vocab.get(token).copied() + } + + fn id_to_token(&self, id: u32) -> Option { + Some(self.vocab_r[&id].clone()) + } + + fn save(&self, folder: &Path, name: Option<&str>) -> Result> { + let vocab_file_name = match name { + Some(name) => format!("{name}-vocab.json"), + None => "vocab.json".to_string(), + }; + + // Write vocab.json + let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())] + .iter() + .collect(); + let mut vocab_file = File::create(&vocab_path)?; + let order_vocab_iter = OrderedVocabIter::new(&self.vocab_r); + let serialized = serde_json::to_string(&order_vocab_iter)?; + vocab_file.write_all(serialized.as_bytes())?; + // + // // Write merges.txt + // let merges_file_name = match name { + // Some(name) => format!("{name}-merges.txt"), + // None => "merges.txt".to_string(), + // }; + // + // let merges_path: PathBuf = [folder, Path::new(merges_file_name.as_str())] + // .iter() + // .collect(); + // let mut merges_file = File::create(&merges_path)?; + // let mut merges: Vec<(&Vec<&str, &str>, &u32)> = self + // .merges + // .iter() + // .map(|(pair, (rank, _))| (pair, rank)) + // .collect(); + // merges.sort_unstable_by_key(|k| *k.1); + // merges_file.write_all(b"#version: 0.2\n")?; + // merges_file.write_all( + // &merges + // .into_iter() + // .flat_map(|(pair, _)| { + // format!("{} {}\n", self.vocab_r[&pair.0], self.vocab_r[&pair.1]).into_bytes() + // }) + // .collect::>()[..], + // )?; + Ok(vec![vocab_path]) + // Ok(vec![vocab_path, merges_path]) + } + + fn get_trainer(&self) -> BpeTrainer { + BpeTrainer::default() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::NamedTempFile; + + #[test] + fn my_example() { + let tokens = [ + "a", "b", "c", // 1 character each + "aac", "ac", "cc", "cca", "aacc", "aaccca", "acca", "acc", "aa", "aaa", + "aaaa", // 2 characters each + ]; + let mut bpe = BacktrackingBpe::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None); + // bpe.encode_via_backtracking(b"baacca"); + bpe.encode_via_backtracking(b"aaaacc"); + + bpe.encode_via_backtracking(b"baaaaccca"); + let tokens = [ + "a", "b", "c", // 1 character each + "acca", "cc", "ac", "aac", "cca", + ]; + let mut bpe = BacktrackingBpe::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None); + bpe.encode_via_backtracking(b"baacca"); + } +} diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index 3ab3b495b..55d1e15eb 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -1,5 +1,6 @@ //! Popular tokenizer models. +pub mod backtracking_bpe; pub mod bpe; pub mod unigram; pub mod wordlevel; From 3e884ac703778161d3825fdfb0fb77ec68934f81 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 3 Jan 2025 14:34:26 +0100 Subject: [PATCH 02/11] update test --- .../src/models/backtracking_bpe/model.rs | 70 ++++++++++++++++--- 1 file changed, 62 insertions(+), 8 deletions(-) diff --git a/tokenizers/src/models/backtracking_bpe/model.rs b/tokenizers/src/models/backtracking_bpe/model.rs index 1afd57dfa..b317121b9 100644 --- a/tokenizers/src/models/backtracking_bpe/model.rs +++ b/tokenizers/src/models/backtracking_bpe/model.rs @@ -234,7 +234,7 @@ fn token_iter<'a>(all_tokens: &'a [u8], token_starts: &'a [u32]) -> impl Iterato token_starts .iter() .tuple_windows() - .map( move |(start, end)| &all_tokens[*start as usize..*end as usize]) + .map(move |(start, end)| &all_tokens[*start as usize..*end as usize]) } fn next_match(longest_searcher: &DoubleArrayAhoCorasick, text: &[u8]) -> Option { @@ -483,12 +483,21 @@ impl BacktrackingBpe { } let vocab: HashMap = token_iter(&all_tokens, &token_starts) .enumerate() - .map(|(id, item)| (unsafe { String::from_utf8_unchecked(Vec::from(item.clone())) }, id as u32)) + .map(|(id, item)| { + ( + unsafe { String::from_utf8_unchecked(Vec::from(item.clone())) }, + id as u32, + ) + }) .collect(); let vocab_r: HashMap = token_iter(&all_tokens, &token_starts) .enumerate() - .map(|(id, item)| (id as u32, unsafe { String::from_utf8_unchecked(Vec::from(item.clone())) })) + .map(|(id, item)| { + (id as u32, unsafe { + String::from_utf8_unchecked(Vec::from(item.clone())) + }) + }) .collect(); let bpe = Self { @@ -658,7 +667,8 @@ impl BacktrackingBpe { } pub fn encode_via_backtracking(&self, text: &[u8]) -> Vec { - let mut enc = BacktrackState::new(text, None); + let next_token = self.next_match(text); + let mut enc = BacktrackState::new(text, next_token); while self.step(&mut enc).is_some() {} println!("_______________________________"); enc.into_tokens() @@ -801,16 +811,60 @@ mod tests { "aac", "ac", "cc", "cca", "aacc", "aaccca", "acca", "acc", "aa", "aaa", "aaaa", // 2 characters each ]; - let mut bpe = BacktrackingBpe::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None); + let mut bpe = + BacktrackingBpe::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None, None); // bpe.encode_via_backtracking(b"baacca"); - bpe.encode_via_backtracking(b"aaaacc"); - + let tokens = bpe.tokenize("aaaacc").unwrap(); + println!("{:?}", bpe.tokenize("aaaacc")); + assert_eq!( + tokens, + vec![ + Token { + id: 12, + value: String::from("aaa"), + offsets: (0, 0) + }, + Token { + id: 10, + value: String::from("acc"), + offsets: (0, 0) + } + ] + ); + println!("{:?}", bpe.tokenize("baaaaccca")); + let tokens = bpe.tokenize("baaaaccca").unwrap(); + assert_eq!( + tokens, + vec![ + Token { + id: 1, + value: String::from("b"), + offsets: (0, 0) + }, + Token { + id: 12, + value: String::from("aaa"), + offsets: (0, 0) + }, + Token { + id: 4, + value: String::from("ac"), + offsets: (0, 0) + }, + Token { + id: 6, + value: String::from("cca"), + offsets: (0, 0) + } + ] + ); bpe.encode_via_backtracking(b"baaaaccca"); let tokens = [ "a", "b", "c", // 1 character each "acca", "cc", "ac", "aac", "cca", ]; - let mut bpe = BacktrackingBpe::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None); + let mut bpe = + BacktrackingBpe::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None, None); bpe.encode_via_backtracking(b"baacca"); } } From 0907a4d8c9def21685e8d4dae94bea1e1e330443 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 3 Jan 2025 14:59:30 +0100 Subject: [PATCH 03/11] update benches as well --- bindings/python/src/models.rs | 140 ++++++++++++++++++ tokenizers/benches/llama3.rs | 14 ++ tokenizers/src/models/backtracking_bpe/mod.rs | 2 + .../src/models/backtracking_bpe/model.rs | 14 +- tokenizers/src/models/mod.rs | 11 ++ 5 files changed, 180 insertions(+), 1 deletion(-) diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 0d5c0ddcd..863c69f18 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -36,6 +36,7 @@ impl PyModel { let base = self.clone(); Ok(match *self.model.as_ref().read().unwrap() { ModelWrapper::BPE(_) => Py::new(py, (PyBPE {}, base))?.into_py(py), + ModelWrapper::BacktrackBpe(_) => Py::new(py, (PyBacktrackBPE {}, base))?.into_py(py), ModelWrapper::WordPiece(_) => Py::new(py, (PyWordPiece {}, base))?.into_py(py), ModelWrapper::WordLevel(_) => Py::new(py, (PyWordLevel {}, base))?.into_py(py), ModelWrapper::Unigram(_) => Py::new(py, (PyUnigram {}, base))?.into_py(py), @@ -560,6 +561,145 @@ impl PyBPE { } } +#[pyclass(module = "bpe")] +struct PyBacktrackBpe {} + +#[pymethods] +impl PyBacktrackBpe { + #[getter] + fn get_dropout(self_: PyRef) -> Option { + getter!(self_, BPE, dropout) + } + + #[setter] + fn set_dropout(self_: PyRef, dropout: Option) { + setter!(self_, BPE, dropout, dropout); + } + + #[getter] + fn get_unk_token(self_: PyRef) -> Option { + getter!(self_, BPE, unk_token.clone()) + } + + #[setter] + fn set_unk_token(self_: PyRef, unk_token: Option) { + setter!(self_, BPE, unk_token, unk_token); + } + #[new] + #[pyo3( + signature = (vocab=None, merges=None, **kwargs), + text_signature = "(self, vocab=None, merges=None, dropout=None, unk_token=None)")] + fn new( + py: Python<'_>, + vocab: Option, + merges: Option, + kwargs: Option<&Bound<'_, PyDict>>, + ) -> PyResult<(Self, PyModel)> { + if (vocab.is_some() && merges.is_none()) || (vocab.is_none() && merges.is_some()) { + return Err(exceptions::PyValueError::new_err( + "`vocab` and `merges` must be both specified", + )); + } + + let mut builder = BPE::builder(); + if let (Some(vocab), Some(merges)) = (vocab, merges) { + match (vocab, merges) { + (PyVocab::Vocab(vocab), PyMerges::Merges(merges)) => { + builder = builder.vocab_and_merges(vocab, merges); + } + (PyVocab::Filename(vocab_filename), PyMerges::Filename(merges_filename)) => { + deprecation_warning( + py, + "0.9.0", + "BPE.__init__ will not create from files anymore, try `BPE.from_file` instead", + )?; + builder = + builder.files(vocab_filename.to_string(), merges_filename.to_string()); + } + _ => { + return Err(exceptions::PyValueError::new_err( + "`vocab` and `merges` must be both be from memory or both filenames", + )); + } + } + } + + PyBPE::with_builder(builder, kwargs) + } + + /// Read a :obj:`vocab.json` and a :obj:`merges.txt` files + /// + /// This method provides a way to read and parse the content of these files, + /// returning the relevant data structures. If you want to instantiate some BPE models + /// from memory, this method gives you the expected input from the standard files. + /// + /// Args: + /// vocab (:obj:`str`): + /// The path to a :obj:`vocab.json` file + /// + /// merges (:obj:`str`): + /// The path to a :obj:`merges.txt` file + /// + /// Returns: + /// A :obj:`Tuple` with the vocab and the merges: + /// The vocabulary and merges loaded into memory + #[staticmethod] + #[pyo3(text_signature = "(self, vocab, merges)")] + fn read_file(vocab: &str, merges: &str) -> PyResult<(Vocab, Merges)> { + BPE::read_file(vocab, merges).map_err(|e| { + exceptions::PyException::new_err(format!( + "Error while reading vocab & merges files: {}", + e + )) + }) + } + + /// Instantiate a BPE model from the given files. + /// + /// This method is roughly equivalent to doing:: + /// + /// vocab, merges = BPE.read_file(vocab_filename, merges_filename) + /// bpe = BPE(vocab, merges) + /// + /// If you don't need to keep the :obj:`vocab, merges` values lying around, + /// this method is more optimized than manually calling + /// :meth:`~tokenizers.models.BPE.read_file` to initialize a :class:`~tokenizers.models.BPE` + /// + /// Args: + /// vocab (:obj:`str`): + /// The path to a :obj:`vocab.json` file + /// + /// merges (:obj:`str`): + /// The path to a :obj:`merges.txt` file + /// + /// Returns: + /// :class:`~tokenizers.models.BPE`: An instance of BPE loaded from these files + #[classmethod] + #[pyo3(signature = (vocab, merges, **kwargs))] + #[pyo3(text_signature = "(cls, vocab, merge, **kwargs)")] + fn from_file( + _cls: &Bound<'_, PyType>, + py: Python, + vocab: &str, + merges: &str, + kwargs: Option<&Bound<'_, PyDict>>, + ) -> PyResult> { + let (vocab, merges) = BPE::read_file(vocab, merges).map_err(|e| { + exceptions::PyException::new_err(format!("Error while reading BPE files: {}", e)) + })?; + Py::new( + py, + PyBPE::new( + py, + Some(PyVocab::Vocab(vocab)), + Some(PyMerges::Merges(merges)), + kwargs, + )?, + ) + } +} + + /// An implementation of the WordPiece algorithm /// /// Args: diff --git a/tokenizers/benches/llama3.rs b/tokenizers/benches/llama3.rs index 77af3bd63..f053cd16a 100644 --- a/tokenizers/benches/llama3.rs +++ b/tokenizers/benches/llama3.rs @@ -3,6 +3,7 @@ extern crate criterion; use criterion::{Criterion, Throughput}; use tokenizers::Tokenizer; +use tokenizers::models::backtracking_bpe; pub fn llama3(c: &mut Criterion) { let data = std::fs::read_to_string("data/big.txt").unwrap(); @@ -30,6 +31,19 @@ pub fn llama3(c: &mut Criterion) { .unwrap() }) }); + + group.bench_function("llama3-backtracking", |b| { + let vocab = None; + let model: backtracking_bpe::BacktrackingBpe = backtracking_bpe::BacktrackingBpe::from_dictionary(vocab, None, None); + let tokenizer = Tokenizer::new(model); + let data: Vec<_> = data.lines().collect(); + let add_special_tokens = false; + b.iter(|| { + tokenizer + .encode_batch(criterion::black_box(data.clone()), add_special_tokens) + .unwrap() + }) + }); group.finish(); } diff --git a/tokenizers/src/models/backtracking_bpe/mod.rs b/tokenizers/src/models/backtracking_bpe/mod.rs index 2e52b478d..080cdc3df 100644 --- a/tokenizers/src/models/backtracking_bpe/mod.rs +++ b/tokenizers/src/models/backtracking_bpe/mod.rs @@ -1,2 +1,4 @@ mod bitfield; mod model; + +pub use model::*; \ No newline at end of file diff --git a/tokenizers/src/models/backtracking_bpe/model.rs b/tokenizers/src/models/backtracking_bpe/model.rs index b317121b9..8e1cc3cf3 100644 --- a/tokenizers/src/models/backtracking_bpe/model.rs +++ b/tokenizers/src/models/backtracking_bpe/model.rs @@ -103,7 +103,7 @@ impl Default for BacktrackingBpeBuilder { } /// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model. -#[derive(Serialize, PartialEq)] +#[derive(Serialize, PartialEq, Clone) ] pub struct BacktrackingBpe { /// All the decoded tokens concatenated into? used to build the aho corasick searchers all_tokens: Vec, @@ -150,6 +150,18 @@ pub struct BacktrackingBpe { unk_token: Option, } +use std::fmt; +// Manually implement the Debug trait to exclude the `cache` field +impl fmt::Debug for BacktrackingBpe { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("BacktrackingBpe") + .field("vocab", &self.vocab) + .field("vocab_r", &self.vocab_r) + // Skipping `cache` field here, it won't be included in debug output + .finish() + } +} + impl BacktrackingBpeBuilder { /// Constructs a new `BacktrackingBpeBuilder`. pub fn new() -> Self { diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index 55d1e15eb..a55732e62 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -16,6 +16,7 @@ use crate::models::unigram::{Unigram, UnigramTrainer}; use crate::models::wordlevel::{WordLevel, WordLevelTrainer}; use crate::models::wordpiece::{WordPiece, WordPieceTrainer}; use crate::{AddedToken, Model, Result, Token, Trainer}; +use crate::models::backtracking_bpe::BacktrackingBpe; /// Wraps a vocab mapping (ID -> token) to a struct that will be serialized in order /// of token ID, smallest to largest. @@ -61,7 +62,9 @@ impl<'a> Serialize for OrderedVocabIter<'a> { #[derive(Serialize, Debug, PartialEq, Clone)] #[serde(untagged)] pub enum ModelWrapper { + BPE(BPE), + BacktrackingBpe(BacktrackingBpe), // WordPiece must stay before WordLevel here for deserialization (for retrocompatibility // with the versions not including the "type"), since WordLevel is a subset of WordPiece WordPiece(WordPiece), @@ -140,6 +143,7 @@ impl_enum_from!(WordLevel, ModelWrapper, WordLevel); impl_enum_from!(WordPiece, ModelWrapper, WordPiece); impl_enum_from!(BPE, ModelWrapper, BPE); impl_enum_from!(Unigram, ModelWrapper, Unigram); +impl_enum_from!(BacktrackingBpe, ModelWrapper, BacktrackingBpe); impl Model for ModelWrapper { type Trainer = TrainerWrapper; @@ -150,6 +154,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.tokenize(tokens), Self::BPE(t) => t.tokenize(tokens), Self::Unigram(t) => t.tokenize(tokens), + Self::BacktrackingBpe(t) => t.tokenize(tokens), } } @@ -159,6 +164,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.token_to_id(token), Self::BPE(t) => t.token_to_id(token), Self::Unigram(t) => t.token_to_id(token), + Self::BacktrackingBpe(t) =>t.token_to_id(token), } } @@ -168,6 +174,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.id_to_token(id), Self::BPE(t) => t.id_to_token(id), Self::Unigram(t) => t.id_to_token(id), + Self::BacktrackingBpe(t) => t.id_to_token(id), } } @@ -177,6 +184,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.get_vocab(), Self::BPE(t) => t.get_vocab(), Self::Unigram(t) => t.get_vocab(), + Self::BacktrackingBpe(t) => t.get_vocab(), } } @@ -186,6 +194,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.get_vocab_size(), Self::BPE(t) => t.get_vocab_size(), Self::Unigram(t) => t.get_vocab_size(), + Self::BacktrackingBpe(t) => t.get_vocab_size(), } } @@ -195,6 +204,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.save(folder, name), Self::BPE(t) => t.save(folder, name), Self::Unigram(t) => t.save(folder, name), + Self::BacktrackingBpe(t) =>t.save(folder, name), } } @@ -204,6 +214,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.get_trainer().into(), Self::BPE(t) => t.get_trainer().into(), Self::Unigram(t) => t.get_trainer().into(), + Self::BacktrackingBpe(t) => t.get_trainer().into(), } } } From dca45684849f7aabbcf2e13822f8cfeeb5f842c9 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 3 Jan 2025 16:14:14 +0100 Subject: [PATCH 04/11] nits --- tokenizers/benches/llama3.rs | 42 ++++++++++++------- tokenizers/src/models/backtracking_bpe/mod.rs | 2 +- .../src/models/backtracking_bpe/model.rs | 6 +-- tokenizers/src/models/mod.rs | 17 ++++---- 4 files changed, 38 insertions(+), 29 deletions(-) diff --git a/tokenizers/benches/llama3.rs b/tokenizers/benches/llama3.rs index f053cd16a..ca622e887 100644 --- a/tokenizers/benches/llama3.rs +++ b/tokenizers/benches/llama3.rs @@ -2,40 +2,49 @@ extern crate criterion; use criterion::{Criterion, Throughput}; -use tokenizers::Tokenizer; +use itertools::Itertools; use tokenizers::models::backtracking_bpe; +use tokenizers::Tokenizer; pub fn llama3(c: &mut Criterion) { let data = std::fs::read_to_string("data/big.txt").unwrap(); let mut group = c.benchmark_group("llama3-encode"); group.throughput(Throughput::Bytes(data.bytes().len() as u64)); - group.bench_function("llama3-offsets", |b| { - let tokenizer = - Tokenizer::from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct", None).unwrap(); + + group.bench_function("llama3-backtracking", |b| { + let tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap(); + let mut vocab = &mut tokenizer.get_vocab(false).clone().into_iter().collect::>(); // Convert HashMap into a Vec of (String, u32) tuples + // + vocab.sort_by(|a, b| a.1.cmp(&b.1)); + vocab.truncate(vocab.len().saturating_sub(3)); + let vocab: Vec<_> = vocab // Sort by u32 value + .into_iter() // IntoIterator to get the iterator of Vec + .map(|(tok, _)| Vec::from(tok.as_bytes())) + .collect(); + let model: backtracking_bpe::BacktrackingBpe = + backtracking_bpe::BacktrackingBpe::from_dictionary(vocab, None, None); + let tokenizer = Tokenizer::new(model); let data: Vec<_> = data.lines().collect(); let add_special_tokens = false; b.iter(|| { tokenizer - .encode_batch_char_offsets(criterion::black_box(data.clone()), add_special_tokens) + .encode_batch(criterion::black_box(data.clone()), add_special_tokens) .unwrap() }) }); - group.bench_function("llama3-nooffsets", |b| { - let tokenizer = - Tokenizer::from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct", None).unwrap(); + + group.bench_function("llama3-offsets", |b| { + let tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap(); let data: Vec<_> = data.lines().collect(); let add_special_tokens = false; b.iter(|| { tokenizer - .encode_batch(criterion::black_box(data.clone()), add_special_tokens) + .encode_batch_char_offsets(criterion::black_box(data.clone()), add_special_tokens) .unwrap() }) }); - - group.bench_function("llama3-backtracking", |b| { - let vocab = None; - let model: backtracking_bpe::BacktrackingBpe = backtracking_bpe::BacktrackingBpe::from_dictionary(vocab, None, None); - let tokenizer = Tokenizer::new(model); + group.bench_function("llama3-nooffsets", |b| { + let tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap(); let data: Vec<_> = data.lines().collect(); let add_special_tokens = false; b.iter(|| { @@ -44,13 +53,14 @@ pub fn llama3(c: &mut Criterion) { .unwrap() }) }); + group.finish(); } criterion_group! { - name = bert_benches; + name = llama; config = Criterion::default().sample_size(10); targets = llama3 } -criterion_main!(bert_benches); +criterion_main!(llama); diff --git a/tokenizers/src/models/backtracking_bpe/mod.rs b/tokenizers/src/models/backtracking_bpe/mod.rs index 080cdc3df..0e2aece93 100644 --- a/tokenizers/src/models/backtracking_bpe/mod.rs +++ b/tokenizers/src/models/backtracking_bpe/mod.rs @@ -1,4 +1,4 @@ mod bitfield; mod model; -pub use model::*; \ No newline at end of file +pub use model::*; diff --git a/tokenizers/src/models/backtracking_bpe/model.rs b/tokenizers/src/models/backtracking_bpe/model.rs index 8e1cc3cf3..c9d813e2e 100644 --- a/tokenizers/src/models/backtracking_bpe/model.rs +++ b/tokenizers/src/models/backtracking_bpe/model.rs @@ -103,7 +103,7 @@ impl Default for BacktrackingBpeBuilder { } /// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model. -#[derive(Serialize, PartialEq, Clone) ] +#[derive(Serialize, PartialEq, Clone)] pub struct BacktrackingBpe { /// All the decoded tokens concatenated into? used to build the aho corasick searchers all_tokens: Vec, @@ -529,11 +529,12 @@ impl BacktrackingBpe { }; for token_id in 0..bpe.num_tokens() as u32 { let bytes = bpe.token_bytes(token_id); + let strs = bytes.iter().map(|b|char::from(*b)).collect::>(); let tokens = bpe.encode_via_bitfield(bytes); assert_eq!( tokens, vec![token_id], - "token {token_id} with bytes {bytes:?} encodes to {tokens:?} instead of to itself" + "token {token_id} with bytes {bytes:?} (tokens {strs:?} encodes to {tokens:?} instead of to itself" ); } bpe @@ -682,7 +683,6 @@ impl BacktrackingBpe { let next_token = self.next_match(text); let mut enc = BacktrackState::new(text, next_token); while self.step(&mut enc).is_some() {} - println!("_______________________________"); enc.into_tokens() } diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index a55732e62..ecf9a5423 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -11,12 +11,12 @@ use std::path::{Path, PathBuf}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use crate::models::backtracking_bpe::BacktrackingBpe; use crate::models::bpe::{BpeTrainer, BPE}; use crate::models::unigram::{Unigram, UnigramTrainer}; use crate::models::wordlevel::{WordLevel, WordLevelTrainer}; use crate::models::wordpiece::{WordPiece, WordPieceTrainer}; use crate::{AddedToken, Model, Result, Token, Trainer}; -use crate::models::backtracking_bpe::BacktrackingBpe; /// Wraps a vocab mapping (ID -> token) to a struct that will be serialized in order /// of token ID, smallest to largest. @@ -62,7 +62,6 @@ impl<'a> Serialize for OrderedVocabIter<'a> { #[derive(Serialize, Debug, PartialEq, Clone)] #[serde(untagged)] pub enum ModelWrapper { - BPE(BPE), BacktrackingBpe(BacktrackingBpe), // WordPiece must stay before WordLevel here for deserialization (for retrocompatibility @@ -154,7 +153,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.tokenize(tokens), Self::BPE(t) => t.tokenize(tokens), Self::Unigram(t) => t.tokenize(tokens), - Self::BacktrackingBpe(t) => t.tokenize(tokens), + Self::BacktrackingBpe(t) => t.tokenize(tokens), } } @@ -164,7 +163,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.token_to_id(token), Self::BPE(t) => t.token_to_id(token), Self::Unigram(t) => t.token_to_id(token), - Self::BacktrackingBpe(t) =>t.token_to_id(token), + Self::BacktrackingBpe(t) => t.token_to_id(token), } } @@ -174,7 +173,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.id_to_token(id), Self::BPE(t) => t.id_to_token(id), Self::Unigram(t) => t.id_to_token(id), - Self::BacktrackingBpe(t) => t.id_to_token(id), + Self::BacktrackingBpe(t) => t.id_to_token(id), } } @@ -184,7 +183,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.get_vocab(), Self::BPE(t) => t.get_vocab(), Self::Unigram(t) => t.get_vocab(), - Self::BacktrackingBpe(t) => t.get_vocab(), + Self::BacktrackingBpe(t) => t.get_vocab(), } } @@ -194,7 +193,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.get_vocab_size(), Self::BPE(t) => t.get_vocab_size(), Self::Unigram(t) => t.get_vocab_size(), - Self::BacktrackingBpe(t) => t.get_vocab_size(), + Self::BacktrackingBpe(t) => t.get_vocab_size(), } } @@ -204,7 +203,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.save(folder, name), Self::BPE(t) => t.save(folder, name), Self::Unigram(t) => t.save(folder, name), - Self::BacktrackingBpe(t) =>t.save(folder, name), + Self::BacktrackingBpe(t) => t.save(folder, name), } } @@ -214,7 +213,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.get_trainer().into(), Self::BPE(t) => t.get_trainer().into(), Self::Unigram(t) => t.get_trainer().into(), - Self::BacktrackingBpe(t) => t.get_trainer().into(), + Self::BacktrackingBpe(t) => t.get_trainer().into(), } } } From d334fb42dc0d8ee2b4b385c7619a2ca8e35153d2 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 3 Jan 2025 16:22:19 +0100 Subject: [PATCH 05/11] add no pretokenizer bench --- tokenizers/benches/llama3.rs | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/tokenizers/benches/llama3.rs b/tokenizers/benches/llama3.rs index ca622e887..f327c0dbb 100644 --- a/tokenizers/benches/llama3.rs +++ b/tokenizers/benches/llama3.rs @@ -4,6 +4,7 @@ extern crate criterion; use criterion::{Criterion, Throughput}; use itertools::Itertools; use tokenizers::models::backtracking_bpe; +use tokenizers::PreTokenizerWrapper; use tokenizers::Tokenizer; pub fn llama3(c: &mut Criterion) { @@ -12,7 +13,7 @@ pub fn llama3(c: &mut Criterion) { group.throughput(Throughput::Bytes(data.bytes().len() as u64)); group.bench_function("llama3-backtracking", |b| { - let tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap(); + let mut tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap(); let mut vocab = &mut tokenizer.get_vocab(false).clone().into_iter().collect::>(); // Convert HashMap into a Vec of (String, u32) tuples // vocab.sort_by(|a, b| a.1.cmp(&b.1)); @@ -23,7 +24,30 @@ pub fn llama3(c: &mut Criterion) { .collect(); let model: backtracking_bpe::BacktrackingBpe = backtracking_bpe::BacktrackingBpe::from_dictionary(vocab, None, None); - let tokenizer = Tokenizer::new(model); + tokenizer.with_model(model); + let data: Vec<_> = data.lines().collect(); + let add_special_tokens = false; + b.iter(|| { + tokenizer + .encode_batch(criterion::black_box(data.clone()), add_special_tokens) + .unwrap() + }) + }); + + group.bench_function("llama3-backtracking-no-pretok", |b| { + let mut tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap(); + let mut vocab = &mut tokenizer.get_vocab(false).clone().into_iter().collect::>(); // Convert HashMap into a Vec of (String, u32) tuples + // + vocab.sort_by(|a, b| a.1.cmp(&b.1)); + vocab.truncate(vocab.len().saturating_sub(3)); + let vocab: Vec<_> = vocab // Sort by u32 value + .into_iter() // IntoIterator to get the iterator of Vec + .map(|(tok, _)| Vec::from(tok.as_bytes())) + .collect(); + let model: backtracking_bpe::BacktrackingBpe = + backtracking_bpe::BacktrackingBpe::from_dictionary(vocab, None, None); + tokenizer.with_model(model); + tokenizer.with_pre_tokenizer(None::); let data: Vec<_> = data.lines().collect(); let add_special_tokens = false; b.iter(|| { From a3ed0c3e8d08a1ebcf33736a58e70dabb09530ba Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 3 Jan 2025 17:28:42 +0100 Subject: [PATCH 06/11] push latest changes --- tokenizers/benches/llama3.rs | 10 +- tokenizers/src/models/backtracking_bpe/mod.rs | 1 + .../src/models/backtracking_bpe/model.rs | 95 +++++---- .../models/backtracking_bpe/serialization.rs | 180 ++++++++++++++++++ tokenizers/src/models/bpe/mod.rs | 2 +- tokenizers/src/models/mod.rs | 4 +- 6 files changed, 246 insertions(+), 46 deletions(-) create mode 100644 tokenizers/src/models/backtracking_bpe/serialization.rs diff --git a/tokenizers/benches/llama3.rs b/tokenizers/benches/llama3.rs index f327c0dbb..05efcbd62 100644 --- a/tokenizers/benches/llama3.rs +++ b/tokenizers/benches/llama3.rs @@ -29,7 +29,7 @@ pub fn llama3(c: &mut Criterion) { let add_special_tokens = false; b.iter(|| { tokenizer - .encode_batch(criterion::black_box(data.clone()), add_special_tokens) + .encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens) .unwrap() }) }); @@ -52,22 +52,22 @@ pub fn llama3(c: &mut Criterion) { let add_special_tokens = false; b.iter(|| { tokenizer - .encode_batch(criterion::black_box(data.clone()), add_special_tokens) + .encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens) .unwrap() }) }); - group.bench_function("llama3-offsets", |b| { + group.bench_function("llama3-encode_batch_fast", |b| { let tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap(); let data: Vec<_> = data.lines().collect(); let add_special_tokens = false; b.iter(|| { tokenizer - .encode_batch_char_offsets(criterion::black_box(data.clone()), add_special_tokens) + .encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens) .unwrap() }) }); - group.bench_function("llama3-nooffsets", |b| { + group.bench_function("llama3-encode_batch", |b| { let tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap(); let data: Vec<_> = data.lines().collect(); let add_special_tokens = false; diff --git a/tokenizers/src/models/backtracking_bpe/mod.rs b/tokenizers/src/models/backtracking_bpe/mod.rs index 0e2aece93..2c17cbf46 100644 --- a/tokenizers/src/models/backtracking_bpe/mod.rs +++ b/tokenizers/src/models/backtracking_bpe/mod.rs @@ -1,4 +1,5 @@ mod bitfield; mod model; +mod serialization; pub use model::*; diff --git a/tokenizers/src/models/backtracking_bpe/model.rs b/tokenizers/src/models/backtracking_bpe/model.rs index c9d813e2e..2a13c4ac2 100644 --- a/tokenizers/src/models/backtracking_bpe/model.rs +++ b/tokenizers/src/models/backtracking_bpe/model.rs @@ -1,6 +1,6 @@ use super::bitfield::BitField; use super::{super::bpe::trainer::BpeTrainer, super::bpe::Error, super::OrderedVocabIter}; -use crate::models::bpe::BPE; +use crate::models::bpe::{MergeMap, Pair, BPE}; use crate::tokenizer::{Model, Result, Token}; use crate::utils::iter::ResultShunt; use aneubeck_daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder}; @@ -24,6 +24,7 @@ pub type Vocab = HashMap; type VocabR = HashMap; pub type Merges = Vec<(String, String)>; + /// This can be thought of as a lazy variation of the dynamic programming approach. /// It only computes those states which have to be visited in order to compute the tokenization /// for a given input text. @@ -102,8 +103,9 @@ impl Default for BacktrackingBpeBuilder { } } + /// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model. -#[derive(Serialize, PartialEq, Clone)] +#[derive(PartialEq, Clone)] pub struct BacktrackingBpe { /// All the decoded tokens concatenated into? used to build the aho corasick searchers all_tokens: Vec, @@ -122,21 +124,18 @@ pub struct BacktrackingBpe { // serialize_with = "serialize_daac", // deserialize_with = "deserialize_daac" // )] - #[serde(skip)] longest_searcher: DoubleArrayAhoCorasick, /// An aho corasick automaton to find ALL tokens in a byte sequence. // #[serde( // serialize_with = "serialize_daac", // deserialize_with = "deserialize_daac" // )] - #[serde(skip)] pub(crate) overlapping_searcher: DoubleArrayAhoCorasick, /// An aho corasick automaton to find ALL tokens in a byte sequence which is being processed in reverse order. // #[serde( // serialize_with = "serialize_daac", // deserialize_with = "deserialize_daac" // )] - #[serde(skip)] pub(crate) overlapping_searcher_rev: DoubleArrayAhoCorasick, /// Mapping from a token to the next longest prefix token. /// This is in principle information represented by the AhoCorasick automaton. @@ -145,12 +144,14 @@ pub struct BacktrackingBpe { next_prefix_match: Vec, /// Hash factor used to prevent hash collisions. hash_factor: u64, - vocab: Vocab, - vocab_r: VocabR, + pub vocab: Vocab, + pub vocab_r: VocabR, unk_token: Option, + pub merges: MergeMap, } use std::fmt; + // Manually implement the Debug trait to exclude the `cache` field impl fmt::Debug for BacktrackingBpe { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -430,7 +431,7 @@ impl BacktrackingBpe { /// to prevent repeating the cost of computing the hash factor and encoding. pub fn from_dictionary( tokens: impl IntoIterator>, - merges: Option>, + merges: Option, hash_factor: Option, ) -> Self { let hash_factor = hash_factor @@ -440,6 +441,7 @@ impl BacktrackingBpe { let mut all_tokens_rev = Vec::new(); let mut token_starts = vec![0]; let mut bytes_hash_to_token = FnvHashMap::default(); + let mut merge_map :HashMap = HashMap::new(); for (i, token) in tokens.into_iter().enumerate() { bytes_hash_to_token.insert(hash_bytes(&token, hash_factor), i as u32); all_tokens_rev.extend(token.iter().copied().rev()); @@ -463,36 +465,6 @@ impl BacktrackingBpe { }) .collect(); - let mut split_table = vec![]; - let mut pair_lookup = FnvHashMap::default(); - - // Reverse engineer the merge/split table. - for (id, token) in token_iter(&all_tokens, &token_starts).enumerate() { - let mut token1 = next_prefix_match[id]; - while token1 != u32::MAX { - let rest = &token[token_range(&token_starts, token1).len()..]; - if let Some(token2) = find_token_by_bytes( - &all_tokens, - &token_starts, - &bytes_hash_to_token, - rest, - hash_factor, - ) { - if token1 < id as u32 - && token2 < id as u32 - && is_valid_token_pair(&pair_lookup, &split_table, token1, token2) - { - pair_lookup.insert((token1, token2), id as u32); - split_table.push((token1, token2)); - break; - } - } - token1 = next_prefix_match[token1 as usize]; - } - if token1 == u32::MAX { - split_table.push((id as u32, id as u32)); - } - } let vocab: HashMap = token_iter(&all_tokens, &token_starts) .enumerate() .map(|(id, item)| { @@ -512,6 +484,52 @@ impl BacktrackingBpe { }) .collect(); + let mut split_table = vec![]; + let mut pair_lookup = FnvHashMap::default(); + + if let Some(ref merges) = merges { + for (id, pair) in merges.into_iter().enumerate(){ + let token1 = vocab[&pair.0.clone()]; + let token2 = vocab[&pair.1.clone()]; + pair_lookup.insert((token1, token2), id as u32); + split_table.push((token1, token2)); + merge_map.insert(Pair::from(pair), (id as u32, id as u32)); // TODO wrong + }; + } else { + // Reverse engineer the merge/split table. + for (id, token) in token_iter(&all_tokens, &token_starts).enumerate() { + let mut token1 = next_prefix_match[id]; + while token1 != u32::MAX { + let rest = &token[token_range(&token_starts, token1).len()..]; + if let Some(token2) = find_token_by_bytes( + &all_tokens, + &token_starts, + &bytes_hash_to_token, + rest, + hash_factor, + ) { + if token1 < id as u32 + && token2 < id as u32 + && is_valid_token_pair(&pair_lookup, &split_table, token1, token2) + { + pair_lookup.insert((token1, token2), id as u32); + split_table.push((token1, token2)); + let str_token1 = unsafe { String::from_utf8_unchecked(Vec::from(&all_tokens[token_range(&token_starts, token1)]))}; + let str_token2 = unsafe { String::from_utf8_unchecked(Vec::from(&all_tokens[token_range(&token_starts, token2)]))}; + merge_map.insert(Pair::from(&(str_token1,str_token2)), (id as u32, id as u32)); // TODO wrong + break; + } + } + token1 = next_prefix_match[token1 as usize]; + } + if token1 == u32::MAX { + split_table.push((id as u32, id as u32)); + } + } + }; + + + let bpe = Self { all_tokens, token_starts, @@ -526,6 +544,7 @@ impl BacktrackingBpe { unk_token: None, vocab, vocab_r, + merges: merge_map }; for token_id in 0..bpe.num_tokens() as u32 { let bytes = bpe.token_bytes(token_id); diff --git a/tokenizers/src/models/backtracking_bpe/serialization.rs b/tokenizers/src/models/backtracking_bpe/serialization.rs new file mode 100644 index 000000000..3218faa38 --- /dev/null +++ b/tokenizers/src/models/backtracking_bpe/serialization.rs @@ -0,0 +1,180 @@ +use super::{super::OrderedVocabIter, convert_merges_to_hashmap, BacktrackingBpe, BacktrackingBpeBuilder, super::bpe::Pair }; +use serde::{ + de::{Error, MapAccess, Visitor}, + ser::SerializeStruct, + Deserialize, Deserializer, Serialize, Serializer, +}; +use std::collections::HashMap; + +impl Serialize for BacktrackingBpe { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut model = serializer.serialize_struct("BPE", 8)?; + + // Start by small fields + model.serialize_field("type", "BPE")?; + + // Then the large ones + let mut merges: Vec<(&Pair, &u32)> = self + .merges + .iter() + .map(|(pair, (rank, _))| (pair, rank)) + .collect(); + merges.sort_unstable_by_key(|k| *k.1); + let merges = merges + .into_iter() + .map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone())) + .collect::>(); + let ordered_vocab = OrderedVocabIter::new(&self.vocab_r); + + model.serialize_field("vocab", &ordered_vocab)?; + model.serialize_field("merges", &merges)?; + + model.end() + } +} + +impl<'de> Deserialize<'de> for BacktrackingBpe { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_struct( + "BPE", + &[ + "type", + "dropout", + "unk_token", + "vocab", + "merges", + ], + BPEVisitor, + ) + } +} + +struct BPEVisitor; +impl<'de> Visitor<'de> for BPEVisitor { + type Value = BacktrackingBpe; + + fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(fmt, "struct BPE") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: MapAccess<'de>, + { + let mut builder =BacktrackingBpeBuilder::new(); + let mut vocab: Option> = None; + + #[derive(Debug, Deserialize)] + #[serde(untagged)] + enum MergeType { + Tuple(Vec<(String, String)>), + Legacy(Vec), + } + let mut merges: Option = None; + while let Some(key) = map.next_key::()? { + match key.as_ref() { + "dropout" => { + if let Some(dropout) = map.next_value()? { + builder = builder.dropout(dropout); + } + } + "unk_token" => { + if let Some(unk) = map.next_value()? { + builder = builder.unk_token(unk); + } + } + "vocab" => vocab = Some(map.next_value()?), + "merges" => merges = Some(map.next_value()?), + "type" => match map.next_value()? { + "BPE" => {} + u => { + return Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Str(u), + &"BacktrackingBpe", + )) + } + }, + _ => {} + } + } + if let (Some(vocab), Some(merges)) = (vocab, merges) { + let merges = match merges { + MergeType::Tuple(merges) => merges, + MergeType::Legacy(merges) => { + convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(Error::custom)? + } + }; + builder = builder.vocab_and_merges(vocab, merges); + Ok(builder.build().map_err(Error::custom)?) + } else { + Err(Error::custom("Missing vocab/merges")) + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::models::bpe::Vocab; + + #[test] + fn test_serialization() { + let vocab: Vocab = [ + ("".into(), 0), + ("a".into(), 1), + ("b".into(), 2), + ("ab".into(), 3), + ] + .iter() + .cloned() + .collect(); + let bpe = BacktrackingBpeBuilder::default() + .vocab_and_merges(vocab, vec![("a".to_string(), "b".to_string())]) + .unk_token("".to_string()) + .build() + .unwrap(); + + let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#; + let legacy = serde_json::from_str(legacy).unwrap(); + assert_eq!(bpe, legacy); + + let data = serde_json::to_string(&bpe).unwrap(); + assert_eq!( + data, + r#"{"type":"BPE","dropout":null,"unk_token":"","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"":0,"a":1,"b":2,"ab":3},"merges":[["a","b"]]}"# + ); + let reconstructed = serde_json::from_str(&data).unwrap(); + assert_eq!(bpe, reconstructed); + + // With a space in the token + let vocab: Vocab = [ + ("".into(), 0), + ("a".into(), 1), + ("b c d".into(), 2), + ("ab c d".into(), 3), + ] + .iter() + .cloned() + .collect(); + let bpe =BacktrackingBpeBuilder::default() + .vocab_and_merges(vocab, vec![("a".to_string(), "b c d".to_string())]) + .unk_token("".to_string()) + .build() + .unwrap(); + let data = serde_json::to_string(&bpe).unwrap(); + assert_eq!( + data, + r#"{"type":"BPE","dropout":null,"unk_token":"","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"":0,"a":1,"b c d":2,"ab c d":3},"merges":[["a","b c d"]]}"# + ); + let reconstructed = serde_json::from_str(&data).unwrap(); + assert_eq!(bpe, reconstructed); + } + + +} diff --git a/tokenizers/src/models/bpe/mod.rs b/tokenizers/src/models/bpe/mod.rs index f0d40b2df..6bd51cb1b 100644 --- a/tokenizers/src/models/bpe/mod.rs +++ b/tokenizers/src/models/bpe/mod.rs @@ -6,7 +6,7 @@ mod serialization; pub mod trainer; mod word; -type Pair = (u32, u32); +pub(crate) type Pair = (u32, u32); /// Errors that can be encountered while using or constructing a `BPE` model. #[derive(thiserror::Error, Debug)] diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index ecf9a5423..da8adddcf 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -101,7 +101,7 @@ impl<'de> Deserialize<'de> for ModelWrapper { #[derive(Deserialize)] #[serde(untagged)] pub enum ModelUntagged { - BPE(BPE), + BPE(BacktrackingBpe), // WordPiece must stay before WordLevel here for deserialization (for retrocompatibility // with the versions not including the "type"), since WordLevel is a subset of WordPiece WordPiece(WordPiece), @@ -128,7 +128,7 @@ impl<'de> Deserialize<'de> for ModelWrapper { ModelHelper::Legacy(value) => { let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?; match untagged { - ModelUntagged::BPE(bpe) => ModelWrapper::BPE(bpe), + ModelUntagged::BPE(bpe) => ModelWrapper::BacktrackingBpe(bpe), ModelUntagged::WordPiece(bpe) => ModelWrapper::WordPiece(bpe), ModelUntagged::WordLevel(bpe) => ModelWrapper::WordLevel(bpe), ModelUntagged::Unigram(bpe) => ModelWrapper::Unigram(bpe), From 7c9e5349f07a8dd349283ca56e63fa36c372cd04 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 4 Jan 2025 09:11:23 +0100 Subject: [PATCH 07/11] updates --- tokenizers/benches/llama3.rs | 20 +- .../backtracking_bpe/backtracking_state.rs | 49 + tokenizers/src/models/backtracking_bpe/mod.rs | 3 + .../src/models/backtracking_bpe/model.rs | 70 +- .../models/backtracking_bpe/serialization.rs | 50 +- .../src/models/backtracking_bpe/trainer.rs | 856 ++++++++++++++++++ tokenizers/src/models/bpe/mod.rs | 4 +- tokenizers/src/models/bpe/word.rs | 10 +- tokenizers/src/models/mod.rs | 22 +- 9 files changed, 981 insertions(+), 103 deletions(-) create mode 100644 tokenizers/src/models/backtracking_bpe/backtracking_state.rs create mode 100644 tokenizers/src/models/backtracking_bpe/trainer.rs diff --git a/tokenizers/benches/llama3.rs b/tokenizers/benches/llama3.rs index 05efcbd62..f7104d757 100644 --- a/tokenizers/benches/llama3.rs +++ b/tokenizers/benches/llama3.rs @@ -14,11 +14,15 @@ pub fn llama3(c: &mut Criterion) { group.bench_function("llama3-backtracking", |b| { let mut tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap(); - let mut vocab = &mut tokenizer.get_vocab(false).clone().into_iter().collect::>(); // Convert HashMap into a Vec of (String, u32) tuples - // + let mut vocab = &mut tokenizer + .get_vocab(false) + .clone() + .into_iter() + .collect::>(); // Convert HashMap into a Vec of (String, u32) tuples + // vocab.sort_by(|a, b| a.1.cmp(&b.1)); vocab.truncate(vocab.len().saturating_sub(3)); - let vocab: Vec<_> = vocab // Sort by u32 value + let vocab: Vec<_> = vocab // Sort by u32 value .into_iter() // IntoIterator to get the iterator of Vec .map(|(tok, _)| Vec::from(tok.as_bytes())) .collect(); @@ -36,11 +40,15 @@ pub fn llama3(c: &mut Criterion) { group.bench_function("llama3-backtracking-no-pretok", |b| { let mut tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap(); - let mut vocab = &mut tokenizer.get_vocab(false).clone().into_iter().collect::>(); // Convert HashMap into a Vec of (String, u32) tuples - // + let mut vocab = &mut tokenizer + .get_vocab(false) + .clone() + .into_iter() + .collect::>(); // Convert HashMap into a Vec of (String, u32) tuples + // vocab.sort_by(|a, b| a.1.cmp(&b.1)); vocab.truncate(vocab.len().saturating_sub(3)); - let vocab: Vec<_> = vocab // Sort by u32 value + let vocab: Vec<_> = vocab // Sort by u32 value .into_iter() // IntoIterator to get the iterator of Vec .map(|(tok, _)| Vec::from(tok.as_bytes())) .collect(); diff --git a/tokenizers/src/models/backtracking_bpe/backtracking_state.rs b/tokenizers/src/models/backtracking_bpe/backtracking_state.rs new file mode 100644 index 000000000..98fe15381 --- /dev/null +++ b/tokenizers/src/models/backtracking_bpe/backtracking_state.rs @@ -0,0 +1,49 @@ +use super::bitfield::BitField; + +/// This can be thought of as a lazy variation of the dynamic programming approach. +/// It only computes those states which have to be visited in order to compute the tokenization +/// for a given input text. +/// It keeps track of visited states in a bitfield and only remembers the tokenization +/// of the currently processed dynamic programming state. +/// +/// The biggest downside of this approach is that the search for the longest leftmost match (the firt token?) +/// has to be reset at every (backtracking) step which is still a net win in practice compared to other approaches. +#[derive(Clone, PartialEq)] +pub struct BacktrackState<'a> { + pub(crate) text: &'a [u8], + pub(crate) tokens: Vec, // len of the tezt / 3 + pub(crate) next_token: Option, // bpe.next_match(text) wich is longest_searcher.leftmost_find_iter(text)'s first match value + pub(crate) pos: usize, // current pos in the text? + pub(crate) bitfield: BitField, // keeps track of token boundaries? keeps track of all the valid tokenization positions and making the runtime linear in the input length. +} + +impl<'a> BacktrackState<'a> { + pub(crate) fn new(text: &'a [u8], next_token: Option) -> Self { + Self::with_capacity(text, next_token, text.len() / 3) + } + + pub(crate) fn with_capacity(text: &'a [u8], next_token: Option, cap: usize) -> Self { + Self { + text, + tokens: Vec::with_capacity(cap), + next_token, + pos: 0, + bitfield: BitField::new(text.len() + 1), + } + } + pub(crate) fn count(&self) -> usize { + self.tokens.len() + } + + pub(crate) fn pos(&self) -> usize { + self.pos + } + + pub(crate) fn last_token(&self) -> Option { + self.tokens.last().copied() + } + + pub(crate) fn into_tokens(self) -> Vec { + self.tokens + } +} diff --git a/tokenizers/src/models/backtracking_bpe/mod.rs b/tokenizers/src/models/backtracking_bpe/mod.rs index 2c17cbf46..16f75787e 100644 --- a/tokenizers/src/models/backtracking_bpe/mod.rs +++ b/tokenizers/src/models/backtracking_bpe/mod.rs @@ -1,5 +1,8 @@ +mod backtracking_state; mod bitfield; mod model; mod serialization; +pub mod trainer; pub use model::*; +pub use trainer::*; diff --git a/tokenizers/src/models/backtracking_bpe/model.rs b/tokenizers/src/models/backtracking_bpe/model.rs index 2a13c4ac2..020447b1e 100644 --- a/tokenizers/src/models/backtracking_bpe/model.rs +++ b/tokenizers/src/models/backtracking_bpe/model.rs @@ -24,54 +24,7 @@ pub type Vocab = HashMap; type VocabR = HashMap; pub type Merges = Vec<(String, String)>; - -/// This can be thought of as a lazy variation of the dynamic programming approach. -/// It only computes those states which have to be visited in order to compute the tokenization -/// for a given input text. -/// It keeps track of visited states in a bitfield and only remembers the tokenization -/// of the currently processed dynamic programming state. -/// -/// The biggest downside of this approach is that the search for the longest leftmost match (the firt token?) -/// has to be reset at every (backtracking) step which is still a net win in practice compared to other approaches. -#[derive(Clone, PartialEq)] -pub(crate) struct BacktrackState<'a> { - text: &'a [u8], - tokens: Vec, // len of the tezt / 3 - next_token: Option, // bpe.next_match(text) wich is longest_searcher.leftmost_find_iter(text)'s first match value - pos: usize, // current pos in the text? - bitfield: BitField, // keeps track of token boundaries? keeps track of all the valid tokenization positions and making the runtime linear in the input length. -} - -impl<'a> BacktrackState<'a> { - pub(crate) fn new(text: &'a [u8], next_token: Option) -> Self { - Self::with_capacity(text, next_token, text.len() / 3) - } - - pub(crate) fn with_capacity(text: &'a [u8], next_token: Option, cap: usize) -> Self { - Self { - text, - tokens: Vec::with_capacity(cap), - next_token, - pos: 0, - bitfield: BitField::new(text.len() + 1), - } - } - pub(crate) fn count(&self) -> usize { - self.tokens.len() - } - - pub(crate) fn pos(&self) -> usize { - self.pos - } - - pub(crate) fn last_token(&self) -> Option { - self.tokens.last().copied() - } - - pub(crate) fn into_tokens(self) -> Vec { - self.tokens - } -} +use super::backtracking_state::BacktrackState; struct Config { files: Option<(String, String)>, @@ -103,7 +56,6 @@ impl Default for BacktrackingBpeBuilder { } } - /// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model. #[derive(PartialEq, Clone)] pub struct BacktrackingBpe { @@ -441,7 +393,7 @@ impl BacktrackingBpe { let mut all_tokens_rev = Vec::new(); let mut token_starts = vec![0]; let mut bytes_hash_to_token = FnvHashMap::default(); - let mut merge_map :HashMap = HashMap::new(); + let mut merge_map: HashMap = HashMap::new(); for (i, token) in tokens.into_iter().enumerate() { bytes_hash_to_token.insert(hash_bytes(&token, hash_factor), i as u32); all_tokens_rev.extend(token.iter().copied().rev()); @@ -488,13 +440,14 @@ impl BacktrackingBpe { let mut pair_lookup = FnvHashMap::default(); if let Some(ref merges) = merges { - for (id, pair) in merges.into_iter().enumerate(){ + for (id, pair) in merges.into_iter().enumerate() { let token1 = vocab[&pair.0.clone()]; let token2 = vocab[&pair.1.clone()]; pair_lookup.insert((token1, token2), id as u32); split_table.push((token1, token2)); - merge_map.insert(Pair::from(pair), (id as u32, id as u32)); // TODO wrong - }; + merge_map.insert(Pair::from((token1, token2)), (id as u32, id as u32)); + // TODO wrong + } } else { // Reverse engineer the merge/split table. for (id, token) in token_iter(&all_tokens, &token_starts).enumerate() { @@ -514,9 +467,7 @@ impl BacktrackingBpe { { pair_lookup.insert((token1, token2), id as u32); split_table.push((token1, token2)); - let str_token1 = unsafe { String::from_utf8_unchecked(Vec::from(&all_tokens[token_range(&token_starts, token1)]))}; - let str_token2 = unsafe { String::from_utf8_unchecked(Vec::from(&all_tokens[token_range(&token_starts, token2)]))}; - merge_map.insert(Pair::from(&(str_token1,str_token2)), (id as u32, id as u32)); // TODO wrong + merge_map.insert(Pair::from((token1, token2)), (id as u32, id as u32)); break; } } @@ -528,8 +479,6 @@ impl BacktrackingBpe { } }; - - let bpe = Self { all_tokens, token_starts, @@ -544,11 +493,11 @@ impl BacktrackingBpe { unk_token: None, vocab, vocab_r, - merges: merge_map + merges: merge_map, }; for token_id in 0..bpe.num_tokens() as u32 { let bytes = bpe.token_bytes(token_id); - let strs = bytes.iter().map(|b|char::from(*b)).collect::>(); + let strs = bytes.iter().map(|b| char::from(*b)).collect::>(); let tokens = bpe.encode_via_bitfield(bytes); assert_eq!( tokens, @@ -751,6 +700,7 @@ impl BacktrackingBpe { ) -> impl Iterator + 'a { word.into_iter() .map(move |id| Token::new(*id, self.vocab_r[&id].clone(), (0usize, 0usize))) + // TODO offsets should be easy to integrate as well! } } impl Model for BacktrackingBpe { diff --git a/tokenizers/src/models/backtracking_bpe/serialization.rs b/tokenizers/src/models/backtracking_bpe/serialization.rs index 3218faa38..214d860d1 100644 --- a/tokenizers/src/models/backtracking_bpe/serialization.rs +++ b/tokenizers/src/models/backtracking_bpe/serialization.rs @@ -1,4 +1,7 @@ -use super::{super::OrderedVocabIter, convert_merges_to_hashmap, BacktrackingBpe, BacktrackingBpeBuilder, super::bpe::Pair }; +use super::{ + super::bpe::Pair, super::OrderedVocabIter, convert_merges_to_hashmap, BacktrackingBpe, + BacktrackingBpeBuilder, +}; use serde::{ de::{Error, MapAccess, Visitor}, ser::SerializeStruct, @@ -11,7 +14,7 @@ impl Serialize for BacktrackingBpe { where S: Serializer, { - let mut model = serializer.serialize_struct("BPE", 8)?; + let mut model = serializer.serialize_struct("BacktrackingBpe", 8)?; // Start by small fields model.serialize_field("type", "BPE")?; @@ -42,32 +45,26 @@ impl<'de> Deserialize<'de> for BacktrackingBpe { D: Deserializer<'de>, { deserializer.deserialize_struct( - "BPE", - &[ - "type", - "dropout", - "unk_token", - "vocab", - "merges", - ], - BPEVisitor, + "BacktrackingBpe", + &["type", "dropout", "unk_token", "vocab", "merges"], + BacktrackingBpeVisitor, ) } } -struct BPEVisitor; -impl<'de> Visitor<'de> for BPEVisitor { +struct BacktrackingBpeVisitor; +impl<'de> Visitor<'de> for BacktrackingBpeVisitor { type Value = BacktrackingBpe; fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(fmt, "struct BPE") + write!(fmt, "struct BacktrackingBpe") } fn visit_map(self, mut map: V) -> std::result::Result where V: MapAccess<'de>, { - let mut builder =BacktrackingBpeBuilder::new(); + let mut builder = BacktrackingBpeBuilder::new(); let mut vocab: Option> = None; #[derive(Debug, Deserialize)] @@ -92,6 +89,7 @@ impl<'de> Visitor<'de> for BPEVisitor { "vocab" => vocab = Some(map.next_value()?), "merges" => merges = Some(map.next_value()?), "type" => match map.next_value()? { + "BacktrackingBpe" => {} "BPE" => {} u => { return Err(serde::de::Error::invalid_value( @@ -131,23 +129,23 @@ mod test { ("b".into(), 2), ("ab".into(), 3), ] - .iter() - .cloned() - .collect(); + .iter() + .cloned() + .collect(); let bpe = BacktrackingBpeBuilder::default() .vocab_and_merges(vocab, vec![("a".to_string(), "b".to_string())]) .unk_token("".to_string()) .build() .unwrap(); - let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#; + let legacy = r#"{"type":"BacktrackingBpe","dropout":null,"unk_token":"","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#; let legacy = serde_json::from_str(legacy).unwrap(); assert_eq!(bpe, legacy); let data = serde_json::to_string(&bpe).unwrap(); assert_eq!( data, - r#"{"type":"BPE","dropout":null,"unk_token":"","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"":0,"a":1,"b":2,"ab":3},"merges":[["a","b"]]}"# + r#"{"type":"BacktrackingBpe","dropout":null,"unk_token":"","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"":0,"a":1,"b":2,"ab":3},"merges":[["a","b"]]}"# ); let reconstructed = serde_json::from_str(&data).unwrap(); assert_eq!(bpe, reconstructed); @@ -159,10 +157,10 @@ mod test { ("b c d".into(), 2), ("ab c d".into(), 3), ] - .iter() - .cloned() - .collect(); - let bpe =BacktrackingBpeBuilder::default() + .iter() + .cloned() + .collect(); + let bpe = BacktrackingBpeBuilder::default() .vocab_and_merges(vocab, vec![("a".to_string(), "b c d".to_string())]) .unk_token("".to_string()) .build() @@ -170,11 +168,9 @@ mod test { let data = serde_json::to_string(&bpe).unwrap(); assert_eq!( data, - r#"{"type":"BPE","dropout":null,"unk_token":"","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"":0,"a":1,"b c d":2,"ab c d":3},"merges":[["a","b c d"]]}"# + r#"{"type":"BacktrackingBpe","dropout":null,"unk_token":"","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"":0,"a":1,"b c d":2,"ab c d":3},"merges":[["a","b c d"]]}"# ); let reconstructed = serde_json::from_str(&data).unwrap(); assert_eq!(bpe, reconstructed); } - - } diff --git a/tokenizers/src/models/backtracking_bpe/trainer.rs b/tokenizers/src/models/backtracking_bpe/trainer.rs new file mode 100644 index 000000000..94270c374 --- /dev/null +++ b/tokenizers/src/models/backtracking_bpe/trainer.rs @@ -0,0 +1,856 @@ +#![allow(clippy::map_entry)] + +use super::{ + super::bpe::Pair, super::bpe::WithFirstLastIterator, super::bpe::Word, BacktrackingBpe, +}; +use crate::parallelism::*; +use crate::tokenizer::{AddedToken, Result, Trainer}; +use crate::utils::progress::{ProgressBar, ProgressStyle}; +use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; +use std::collections::{BinaryHeap, HashMap, HashSet}; + +#[derive(Debug, Eq)] +struct Merge { + pair: Pair, + count: u64, + pos: HashSet, +} +impl PartialEq for Merge { + fn eq(&self, other: &Self) -> bool { + self.count == other.count && self.pair == other.pair + } +} +impl PartialOrd for Merge { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for Merge { + fn cmp(&self, other: &Self) -> Ordering { + if self.count != other.count { + self.count.cmp(&other.count) + } else { + // Here we want ascending order + other.pair.cmp(&self.pair) + } + } +} + +struct Config { + min_frequency: u64, + vocab_size: usize, + show_progress: bool, + special_tokens: Vec, + limit_alphabet: Option, + initial_alphabet: HashSet, + continuing_subword_prefix: Option, + end_of_word_suffix: Option, + max_token_length: Option, +} + +/// A `BpeTrainerBuilder` can be used to create a `BpeTrainer` with a custom +/// configuration. +pub struct BpeTrainerBuilder { + config: Config, +} + +impl Default for BpeTrainerBuilder { + fn default() -> Self { + Self { + config: Config { + min_frequency: 0, + vocab_size: 30000, + show_progress: true, + special_tokens: vec![], + limit_alphabet: None, + initial_alphabet: HashSet::new(), + continuing_subword_prefix: None, + end_of_word_suffix: None, + max_token_length: None, + }, + } + } +} + +impl BpeTrainerBuilder { + /// Constructs a new `BpeTrainerBuilder` + pub fn new() -> Self { + Self::default() + } + + /// Set the expected minimum frequency + #[must_use] + pub fn min_frequency(mut self, frequency: u64) -> Self { + self.config.min_frequency = frequency; + self + } + + /// Set the vocabulary size + #[must_use] + pub fn vocab_size(mut self, size: usize) -> Self { + self.config.vocab_size = size; + self + } + + /// Set whether to show progress + #[must_use] + pub fn show_progress(mut self, show: bool) -> Self { + self.config.show_progress = show; + self + } + + /// Set the special tokens + #[must_use] + pub fn special_tokens(mut self, tokens: Vec) -> Self { + self.config.special_tokens = tokens; + self + } + + /// Set whether to limit the alphabet + #[must_use] + pub fn limit_alphabet(mut self, limit: usize) -> Self { + self.config.limit_alphabet = Some(limit); + self + } + + /// Set the initial alphabet + #[must_use] + pub fn initial_alphabet(mut self, alphabet: HashSet) -> Self { + self.config.initial_alphabet = alphabet; + self + } + + /// Set the continuing_subword_prefix + #[must_use] + pub fn continuing_subword_prefix(mut self, prefix: String) -> Self { + self.config.continuing_subword_prefix = Some(prefix); + self + } + + /// Set the end_of_word_suffix + #[must_use] + pub fn end_of_word_suffix(mut self, suffix: String) -> Self { + self.config.end_of_word_suffix = Some(suffix); + self + } + /// Set max_token_length + #[must_use] + pub fn max_token_length(mut self, max_token_length: Option) -> Self { + self.config.max_token_length = max_token_length; + self + } + + /// Constructs the final BpeTrainer + pub fn build(self) -> BacktrackingBpeTrainer { + BacktrackingBpeTrainer { + min_frequency: self.config.min_frequency, + vocab_size: self.config.vocab_size, + show_progress: self.config.show_progress, + special_tokens: self.config.special_tokens, + limit_alphabet: self.config.limit_alphabet, + initial_alphabet: self.config.initial_alphabet, + continuing_subword_prefix: self.config.continuing_subword_prefix, + end_of_word_suffix: self.config.end_of_word_suffix, + max_token_length: self.config.max_token_length, + words: HashMap::new(), + } + } +} + +/// In charge of training a `BacktrackingBpe` model +/// +/// # Examples +/// +/// ``` +/// use tokenizers::tokenizer::Trainer; +/// use tokenizers::models::backtracking_bpe::{BacktrackingBpe, BacktrackingBpeTrainer}; +/// +/// let sequences = vec![ "Hello", "World" ]; +/// +/// let mut trainer = BacktrackingBpeTrainer::default(); +/// trainer.feed(sequences.iter(), |s| Ok(vec![s.to_owned()])); +/// +/// let mut model = BacktrackingBpe::default(); +/// let special_tokens = trainer.train(&mut model).unwrap(); +/// ``` +#[non_exhaustive] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)] +pub struct BacktrackingBpeTrainer { + /// The minimum frequency a pair must have to produce a merge operation + pub min_frequency: u64, + /// The target vocabulary size + pub vocab_size: usize, + /// Whether to show progress while training + pub show_progress: bool, + /// A list of special tokens that the model should know of + pub special_tokens: Vec, + /// Whether to limit the number of initial tokens that can be kept before computing merges + pub limit_alphabet: Option, + /// The initial alphabet we want absolutely to include. This allows to cover + /// some characters that are not necessarily in the training set + pub initial_alphabet: HashSet, + /// An optional prefix to use on any subword that exist only behind another one + pub continuing_subword_prefix: Option, + /// An optional suffix to caracterize and end-of-word subword + pub end_of_word_suffix: Option, + /// An optional parameter to limit the max length of any single token + pub max_token_length: Option, + + words: HashMap, +} + +impl Default for BacktrackingBpeTrainer { + fn default() -> Self { + Self::builder().build() + } +} + +impl BacktrackingBpeTrainer { + pub fn new(min_frequency: u64, vocab_size: usize) -> Self { + Self { + min_frequency, + vocab_size, + ..Default::default() + } + } + + pub fn builder() -> BpeTrainerBuilder { + BpeTrainerBuilder::new() + } + + /// Setup a progress bar if asked to show progress + fn setup_progress(&self) -> Option { + if self.show_progress { + let p = ProgressBar::new(0); + p.set_style( + ProgressStyle::default_bar() + .template("[{elapsed_precise}] {msg:<30!} {wide_bar} {pos:<9!}/{len:>9!}") + .expect("Invalid progress template"), + ); + Some(p) + } else { + None + } + } + + /// Set the progress bar in the finish state + fn finalize_progress(&self, p: &Option, final_len: usize) { + if let Some(p) = p { + p.set_length(final_len as u64); + p.finish(); + println!(); + } + } + + /// Update the progress bar with the new provided length and message + fn update_progress(&self, p: &Option, len: usize, message: &'static str) { + if let Some(p) = p { + p.set_message(message); + p.set_length(len as u64); + p.reset(); + } + } + + /// Add the provided special tokens to the initial vocabulary + fn add_special_tokens(&self, w2id: &mut HashMap, id2w: &mut Vec) { + for token in &self.special_tokens { + if !w2id.contains_key(&token.content) { + id2w.push(token.content.to_owned()); + w2id.insert(token.content.to_owned(), (id2w.len() - 1) as u32); + } + } + } + + /// Compute the initial alphabet and limit it if relevant + fn compute_alphabet( + &self, + wc: &HashMap, + w2id: &mut HashMap, + id2w: &mut Vec, + ) { + // Compute the alphabet from seen words + let mut alphabet: HashMap = HashMap::new(); + for (word, count) in wc { + for c in word.chars() { + alphabet + .entry(c) + .and_modify(|cnt| *cnt += *count as usize) + .or_insert(*count as usize); + } + } + + // Also include anything from the provided initial alphabet + for c in &self.initial_alphabet { + alphabet + .entry(*c) + .and_modify(|cnt| *cnt = usize::MAX) + .or_insert(usize::MAX); + } + + let mut kept = alphabet.iter().collect::>(); + + // Compute the number of chars to remove from the alphabet + // If `limit_alphabet < initial_alphabet.len()`, some of these initial characters + // will be removed + let to_remove = self + .limit_alphabet + .map(|limit| { + if alphabet.len() > limit { + alphabet.len() - limit + } else { + 0 + } + }) + .unwrap_or(0); + + // Remove the unwanted chars + if to_remove > 0 { + kept.sort_unstable_by_key(|k| *k.1); + kept.drain(..to_remove); + } + + // Keep the initial alphabet (sorted for determinism) + kept.sort_unstable_by_key(|k| (*k.0) as u32); + kept.into_iter().for_each(|(c, _)| { + let s = c.to_string(); + if !w2id.contains_key(&s) { + id2w.push(s.clone()); + w2id.insert(s, (id2w.len() - 1) as u32); + } + }); + } + + /// Tokenize words and add subwords to the vocabulary when relevant + fn tokenize_words( + &self, + wc: &HashMap, + w2id: &mut HashMap, + id2w: &mut Vec, + p: &Option, + ) -> (Vec, Vec) { + let mut words: Vec = Vec::with_capacity(wc.len()); + let mut counts: Vec = Vec::with_capacity(wc.len()); + + for (word, count) in wc { + let mut current_word = Word::new(); + counts.push(*count); + + for (is_first, is_last, c) in word.chars().with_first_and_last() { + let mut s = c.to_string(); + if w2id.contains_key(&s) { + // Found the initial char in the authorized alphabet + + // Add the `continuing_subword_prefix` if relevant + if !is_first { + if let Some(prefix) = &self.continuing_subword_prefix { + s = format!("{prefix}{s}"); + } + } + // Add the `end_of_word_suffix` if relevant + if is_last { + if let Some(suffix) = &self.end_of_word_suffix { + s = format!("{s}{suffix}"); + } + } + + // Insert the new formed string if necessary + if !w2id.contains_key(&s) { + id2w.push(s.clone()); + w2id.insert(s.clone(), (id2w.len() - 1) as u32); + } + current_word.add(w2id[&s], 1); // We do not care about the len here + } + } + words.push(current_word); + + if let Some(p) = p { + p.inc(1); + } + } + + (words, counts) + } + + fn count_pairs( + &self, + words: &[Word], + counts: &[u64], + p: &Option, + ) -> (HashMap, HashMap>) { + words + .maybe_par_iter() + .enumerate() + .map(|(i, word)| { + let mut pair_counts = HashMap::new(); + let mut where_to_update: HashMap> = HashMap::new(); + + for window in word.get_chars().windows(2) { + let cur_pair: Pair = (window[0], window[1]); + + // Initialize pair_counts and where_to_update for this pair if we just saw it + if !pair_counts.contains_key(&cur_pair) { + pair_counts.insert(cur_pair, 0); + } + + // Then update counts + let count = counts[i]; + where_to_update + .entry(cur_pair) + .and_modify(|h| { + h.insert(i); + }) + .or_insert_with(|| { + let mut h = HashSet::new(); + h.insert(i); + h + }); + *pair_counts.get_mut(&cur_pair).unwrap() += count as i32; + } + + if let Some(p) = &p { + p.inc(1); + } + + (pair_counts, where_to_update) + }) + .reduce( + || (HashMap::new(), HashMap::new()), + |(mut pair_counts, mut where_to_update), (pc, wtu)| { + for (k, v) in pc { + pair_counts.entry(k).and_modify(|c| *c += v).or_insert(v); + } + for (k, v) in wtu { + where_to_update + .entry(k) + .and_modify(|set| *set = set.union(&v).copied().collect()) + .or_insert(v); + } + (pair_counts, where_to_update) + }, + ) + } + + pub fn do_train( + &self, + word_counts: &HashMap, + model: &mut BacktrackingBpe, + ) -> Result> { + let mut word_to_id: HashMap = HashMap::with_capacity(self.vocab_size); + let mut id_to_word: Vec = Vec::with_capacity(self.vocab_size); + let max_token_length: usize = self.max_token_length.unwrap_or(usize::MAX); + + let progress = self.setup_progress(); + + // + // 1. Add all special tokens to the vocabulary + // + self.add_special_tokens(&mut word_to_id, &mut id_to_word); + + // + // 2. Compute the initial alphabet + // + self.compute_alphabet(word_counts, &mut word_to_id, &mut id_to_word); + + // + // 3. Tokenize words + // + self.update_progress(&progress, word_counts.len(), "Tokenize words"); + let (mut words, counts) = + self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress); + self.finalize_progress(&progress, words.len()); + + // + // 4. Count pairs in words + // + self.update_progress(&progress, words.len(), "Count pairs"); + let (mut pair_counts, mut where_to_update) = self.count_pairs(&words, &counts, &progress); + // Insert them in the queue + let mut queue = BinaryHeap::with_capacity(pair_counts.len()); + where_to_update.drain().for_each(|(pair, pos)| { + let count = pair_counts[&pair]; + if count > 0 { + queue.push(Merge { + pair, + count: count as u64, + pos, + }); + } + }); + self.finalize_progress(&progress, words.len()); + + // + // 5. Do merges + // + self.update_progress(&progress, self.vocab_size, "Compute merges"); + let mut merges: Vec<(Pair, u32)> = vec![]; + loop { + // Stop as soon as we have a big enough vocabulary + if word_to_id.len() >= self.vocab_size { + break; + } + + if queue.is_empty() { + break; + } + + let mut top = queue.pop().unwrap(); + if top.count != pair_counts[&top.pair] as u64 { + top.count = pair_counts[&top.pair] as u64; + queue.push(top); + continue; + } + + if top.count < 1 || self.min_frequency > top.count { + break; + } + + let part_a = &id_to_word[top.pair.0 as usize]; + let mut part_b = id_to_word[top.pair.1 as usize].to_owned(); + + // Build new token + if let Some(prefix) = &self.continuing_subword_prefix { + if part_b.starts_with(prefix) { + let prefix_byte_len = prefix.chars().map(|c| c.len_utf8()).sum(); + part_b = part_b[prefix_byte_len..].to_string(); + } + } + let new_token = format!("{part_a}{part_b}"); + // implement sentencepiece-like merge. + // if this code were to be merged, integrate a way in the python bindings to communicate this variable + // default should be 0/None to maintain previous behavior. 16 is the spm default. + + // Insert new token if it does not already exist + let new_token_id = word_to_id + .get(&new_token) + .copied() + .unwrap_or(id_to_word.len() as u32); + if !word_to_id.contains_key(&new_token) { + id_to_word.push(new_token.clone()); + word_to_id.insert(new_token.clone(), new_token_id); + } + merges.push((top.pair, new_token_id)); + + // Merge the new pair in every words + // Safety: This is just a type assertion, the code below may no longer be safe + // if the type of `pos` changes + let pos: &HashSet = &top.pos; + + let words_len = words.len(); + struct WordPtr(*mut Word); + // Safety: We do not actually use this for concurrent access to the same memory, + // only to different chunks within the same allocation. + unsafe impl Sync for WordPtr {} + let word_start = WordPtr(words.as_mut_ptr()); + + let changes = pos + .maybe_par_iter() + .flat_map(|&i| { + // Safety: + // We are producing a valid pointer since we are indexing in bounds + // + // We can access each `word` here in parallel because each position + // can be there only once (pos is a HashSet). + unsafe { + assert!(i < words_len); + // This is words[i], but avoids needing to go through &T (which triggers UB) + let word = word_start.0.add(i); + // let word: &mut Word = &mut (*word); + (*word) + .merge(top.pair.0, top.pair.1, new_token_id, max_token_length) + .into_iter() + .map(|c| (c, i)) + .collect::>() + } + }) + .collect::>(); + + // Introduce new formed pairs + for ((pair, change), iw) in changes { + let count = change * counts[iw] as i32; + pair_counts + .entry(pair) + .and_modify(|c| *c += count) + .or_insert(count); + if change > 0 { + where_to_update + .entry(pair) + .and_modify(|h| { + h.insert(iw); + }) + .or_insert_with(|| { + let mut h = HashSet::new(); + h.insert(iw); + h + }); + } + } + where_to_update.drain().for_each(|(pair, pos)| { + let count = pair_counts[&pair]; + if count > 0 { + queue.push(Merge { + pair, + count: count as u64, + pos, + }); + } + }); + + if let Some(p) = &progress { + p.inc(1); + } + } + self.finalize_progress(&progress, merges.len()); + + // Transfer new vocab & options to model + model.vocab = word_to_id; + model.vocab_r = model + .vocab + .iter() + .map(|(key, val)| (*val, key.to_owned())) + .collect(); + model.merges = merges + .into_iter() + .enumerate() + .map(|(i, (pair, new_token_id))| (pair, (i as u32, new_token_id))) + .collect(); + + Ok(self.special_tokens.clone()) + } +} + +impl Trainer for BacktrackingBpeTrainer { + type Model = BacktrackingBpe; + + /// Train a BacktrackingBpe model + fn train(&self, model: &mut BacktrackingBpe) -> Result> { + self.do_train(&self.words, model) + } + + /// Whether we should show progress + fn should_show_progress(&self) -> bool { + self.show_progress + } + + fn feed(&mut self, iterator: I, process: F) -> Result<()> + where + I: Iterator + Send, + S: AsRef + Send, + F: Fn(&str) -> Result> + Sync, + { + let words: Result> = iterator + .maybe_par_bridge() + .map(|sequence| { + let words = process(sequence.as_ref())?; + let mut map = HashMap::new(); + for word in words { + map.entry(word).and_modify(|c| *c += 1).or_insert(1); + } + Ok(map) + }) + .reduce( + || Ok(HashMap::new()), + |acc, ws| { + let mut acc = acc?; + for (k, v) in ws? { + acc.entry(k).and_modify(|c| *c += v).or_insert(v); + } + Ok(acc) + }, + ); + + self.words = words?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::{BacktrackingBpe, BacktrackingBpeTrainer, Pair}; + use std::collections::HashMap; + + #[test] + fn test_train() { + let word_counts: HashMap = [ + ("roses".into(), 1), + ("are".into(), 2), + ("red".into(), 1), + ("voilets".into(), 1), + ("blue".into(), 1), + ("BERT".into(), 1), + ("is".into(), 2), + ("big".into(), 1), + ("and".into(), 1), + ("so".into(), 1), + ("GPT-2".into(), 1), + ] + .iter() + .cloned() + .collect(); + let trainer = BacktrackingBpeTrainer::builder() + .show_progress(false) + .min_frequency(2) + .build(); + let mut model = BacktrackingBpe::default(); + trainer.do_train(&word_counts, &mut model).unwrap(); + + // Vocab should contain all of the characters from the `word_counts` mapping + // as well as three merges: 're', 'are', and 'is'. + let expected_vocab: HashMap = [ + ("-".into(), 0), + ("2".into(), 1), + ("B".into(), 2), + ("E".into(), 3), + ("G".into(), 4), + ("P".into(), 5), + ("R".into(), 6), + ("T".into(), 7), + ("a".into(), 8), + ("b".into(), 9), + ("d".into(), 10), + ("e".into(), 11), + ("g".into(), 12), + ("i".into(), 13), + ("l".into(), 14), + ("n".into(), 15), + ("o".into(), 16), + ("r".into(), 17), + ("s".into(), 18), + ("t".into(), 19), + ("u".into(), 20), + ("v".into(), 21), + ("re".into(), 22), + ("are".into(), 23), + ("is".into(), 24), + ] + .iter() + .cloned() + .collect(); + assert_eq!(model.vocab, expected_vocab); + + // The keys in `merges` are pairs of symbols, the values are tuples of (rank, id), + // where 'rank' determines the order in which this merge will be applied during + // tokenization, and 'id' is the vocab id of the symbol resulting from merging + // the pair of symbols in the corresponding key. + let expected_merges: HashMap = [ + ((17, 11), (0, 22)), // 'r' + 'e' -> 're' + ((8, 22), (1, 23)), // 'a' + 're' -> 'are' + ((13, 18), (2, 24)), // 'i' + 's' -> 'is' + ] + .iter() + .cloned() + .collect(); + assert_eq!(model.merges, expected_merges); + } + #[test] + fn bpe_test_max_token_length_16() { + /* bpe_test_max_token_length series of tests test the max_token_length flag of bpetrainer + // this is the more robust version that only tests max length of learned tokens + // (pre) tokenizer settings or vocab can be easily modified when necessary + */ + + let max_token_length = 16; + let long_word_counts: HashMap = [ + ("singlelongtokenwithoutcasechange", 2), + ("singleLongTokenWithCamelCaseChange", 2), + ("Longsingletokenwithpunctu@t!onwithin", 2), + ("Anotherlongsingletokenwithnumberw1th1n", 2), + ("짧은한글문자열짧은한", 2), // korean 10 char + ("긴한글문자열긴한글문자열긴한글문", 2), // korean 16 char + ("短字符串短字符串短字", 2), //simplified chinese 10 char + ("长字符串长字符串长字符串长字符串", 2), // simp. chinese 16 char + ("短い文字列短い文字列", 2), // japanese 10 char + ("長い文字列長い文字列長い文字列長", 2), // japanese 16 char + ("so", 2), + ("GPT-2", 2), + ] + .iter() + .map(|(key, value)| (key.to_string(), *value)) + .collect(); + let trainer = BacktrackingBpeTrainer::builder() + .max_token_length(Some(max_token_length)) + .show_progress(false) + .min_frequency(0) + .build(); + let mut model = BacktrackingBpe::default(); + trainer.do_train(&long_word_counts, &mut model).unwrap(); + let vocab = model.get_vocab(); + for token in vocab.keys() { + assert!( + token.chars().count() <= max_token_length, + "token too long : {} , chars().count() = {}", + token, + token.chars().count() + ) + } + } + #[test] + fn bpe_test_max_token_length_direct_assert() { + /* more direct version of bpe_test_max_token_length test + // directly compares tokens with known expected values. + // maybe unstable depending on specific settings or changes. + */ + let long_word_counts: HashMap = [ + ("sin", 2), + ("Sin", 2), + ("Lon", 2), + ("Ano", 2), + ("짧은한", 2), + ("긴한글", 2), + ("短字符", 2), + ("长字符", 2), + ("短い文", 2), + ("長い文", 2), + ("so", 2), + ("GP", 2), + ] + .iter() + .map(|(key, value)| (key.to_string(), *value)) + .collect(); + let trainer = BacktrackingBpeTrainer::builder() + .max_token_length(Some(2)) + .show_progress(false) + .min_frequency(0) + .build(); + let mut model = BacktrackingBpe::default(); + trainer.do_train(&long_word_counts, &mut model).unwrap(); + let trained_vocab: HashMap = model.get_vocab(); + let expected_vocab: HashMap = [ + ("短", 12), + ("n", 6), + ("i", 5), + ("s", 8), + ("字符", 23), + ("長", 14), + ("긴", 17), + ("い文", 22), + ("L", 2), + ("in", 21), + ("o", 7), + ("은한", 29), + ("S", 4), + ("P", 3), + ("so", 27), + ("符", 13), + ("文", 11), + ("字", 10), + ("짧", 19), + ("GP", 25), + ("글", 16), + ("G", 1), + ("An", 24), + ("长", 15), + ("A", 0), + ("Lo", 26), + ("긴한", 28), + ("い", 9), + ("한", 20), + ("은", 18), + ] + .iter() + .cloned() + .map(|(k, v)| (k.to_string(), v)) + .collect(); + assert_eq!(trained_vocab, expected_vocab) + } +} diff --git a/tokenizers/src/models/bpe/mod.rs b/tokenizers/src/models/bpe/mod.rs index 6bd51cb1b..97d8e12ad 100644 --- a/tokenizers/src/models/bpe/mod.rs +++ b/tokenizers/src/models/bpe/mod.rs @@ -4,7 +4,7 @@ use std::{iter, mem}; mod model; mod serialization; pub mod trainer; -mod word; +pub mod word; pub(crate) type Pair = (u32, u32); @@ -79,4 +79,4 @@ where // Re-export pub use model::*; pub use trainer::*; -use word::*; +pub use word::*; diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs index 6fc8033e3..32720a126 100644 --- a/tokenizers/src/models/bpe/word.rs +++ b/tokenizers/src/models/bpe/word.rs @@ -53,7 +53,7 @@ impl Symbol { } #[derive(Clone, Default)] -pub(super) struct Word { +pub struct Word { symbols: Vec, } impl std::fmt::Debug for Word { @@ -74,7 +74,7 @@ impl std::fmt::Debug for Word { } impl Word { - pub(super) fn new() -> Self { + pub(crate) fn new() -> Self { Word { symbols: vec![] } } @@ -84,7 +84,7 @@ impl Word { } } - pub(super) fn add(&mut self, c: u32, byte_len: usize) { + pub(crate) fn add(&mut self, c: u32, byte_len: usize) { let (prev, next) = { let len = self.symbols.len() as isize; if let Some(last) = self.symbols.last_mut() { @@ -103,7 +103,7 @@ impl Word { }); } - pub(super) fn merge( + pub(crate) fn merge( &mut self, c1: u32, c2: u32, @@ -251,7 +251,7 @@ impl Word { self.symbols.retain(|s| s.len != 0); } - pub(super) fn get_chars(&self) -> Vec { + pub(crate) fn get_chars(&self) -> Vec { self.symbols.iter().map(|s| s.c).collect() } diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index da8adddcf..818c8f35f 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -11,7 +11,7 @@ use std::path::{Path, PathBuf}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use crate::models::backtracking_bpe::BacktrackingBpe; +use crate::models::backtracking_bpe::{BacktrackingBpe, BacktrackingBpeTrainer}; use crate::models::bpe::{BpeTrainer, BPE}; use crate::models::unigram::{Unigram, UnigramTrainer}; use crate::models::wordlevel::{WordLevel, WordLevelTrainer}; @@ -89,6 +89,7 @@ impl<'de> Deserialize<'de> for ModelWrapper { WordPiece, WordLevel, Unigram, + BacktrackingBpe, } #[derive(Deserialize)] @@ -101,7 +102,7 @@ impl<'de> Deserialize<'de> for ModelWrapper { #[derive(Deserialize)] #[serde(untagged)] pub enum ModelUntagged { - BPE(BacktrackingBpe), + BPE(BPE), // WordPiece must stay before WordLevel here for deserialization (for retrocompatibility // with the versions not including the "type"), since WordLevel is a subset of WordPiece WordPiece(WordPiece), @@ -124,11 +125,14 @@ impl<'de> Deserialize<'de> for ModelWrapper { EnumType::Unigram => ModelWrapper::Unigram( serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?, ), + EnumType::BacktrackingBpe => ModelWrapper::BacktrackingBpe( + serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?, + ), }, ModelHelper::Legacy(value) => { let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?; match untagged { - ModelUntagged::BPE(bpe) => ModelWrapper::BacktrackingBpe(bpe), + ModelUntagged::BPE(bpe) => ModelWrapper::BPE(bpe), ModelUntagged::WordPiece(bpe) => ModelWrapper::WordPiece(bpe), ModelUntagged::WordLevel(bpe) => ModelWrapper::WordLevel(bpe), ModelUntagged::Unigram(bpe) => ModelWrapper::Unigram(bpe), @@ -241,6 +245,7 @@ pub enum TrainerWrapper { WordPieceTrainer(WordPieceTrainer), WordLevelTrainer(WordLevelTrainer), UnigramTrainer(UnigramTrainer), + BacktrackingBpeTrainer(BacktrackingBpeTrainer), } impl Trainer for TrainerWrapper { @@ -252,6 +257,7 @@ impl Trainer for TrainerWrapper { Self::WordPieceTrainer(wpt) => wpt.should_show_progress(), Self::WordLevelTrainer(wpt) => wpt.should_show_progress(), Self::UnigramTrainer(wpt) => wpt.should_show_progress(), + Self::BacktrackingBpeTrainer(wpt) => wpt.should_show_progress(), } } @@ -273,6 +279,10 @@ impl Trainer for TrainerWrapper { ModelWrapper::Unigram(u) => t.train(u), _ => Err("UnigramTrainer can only train a Unigram".into()), }, + Self::BacktrackingBpeTrainer(t) => match model { + ModelWrapper::BacktrackingBpe(bpe) => t.train(bpe), + _ => Err("BpeTrainer can only train a BPE".into()), + }, } } @@ -287,6 +297,7 @@ impl Trainer for TrainerWrapper { Self::WordPieceTrainer(wpt) => wpt.feed(iterator, process), Self::WordLevelTrainer(wpt) => wpt.feed(iterator, process), Self::UnigramTrainer(wpt) => wpt.feed(iterator, process), + Self::BacktrackingBpeTrainer(wpt) => wpt.feed(iterator, process), } } } @@ -295,6 +306,11 @@ impl_enum_from!(BpeTrainer, TrainerWrapper, BpeTrainer); impl_enum_from!(WordPieceTrainer, TrainerWrapper, WordPieceTrainer); impl_enum_from!(UnigramTrainer, TrainerWrapper, UnigramTrainer); impl_enum_from!(WordLevelTrainer, TrainerWrapper, WordLevelTrainer); +impl_enum_from!( + BacktrackingBpeTrainer, + TrainerWrapper, + BacktrackingBpeTrainer +); #[cfg(test)] mod tests { From ee18ba9b5cd511c584d452eba95e54bc9f7c42e4 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 4 Jan 2025 09:42:40 +0100 Subject: [PATCH 08/11] nits --- .../src/models/backtracking_bpe/model.rs | 29 +------------------ .../models/backtracking_bpe/serialization.rs | 28 +++++++++++------- 2 files changed, 19 insertions(+), 38 deletions(-) diff --git a/tokenizers/src/models/backtracking_bpe/model.rs b/tokenizers/src/models/backtracking_bpe/model.rs index 020447b1e..25eef7e21 100644 --- a/tokenizers/src/models/backtracking_bpe/model.rs +++ b/tokenizers/src/models/backtracking_bpe/model.rs @@ -323,7 +323,7 @@ impl BacktrackingBpe { let mut start = 0; while start < bytes.len() { let end = bitfield.successor(start + 1); - let token = self.find_token_by_bytes(&bytes[start..end]).expect(""); + let token = self.find_token_by_bytes(&bytes[start..end]).expect(&format!("Could not convert bytes to tokens for bytes: [{:?}]", bytes.into_iter().map(|b| char::from(*b)).join(","))); encoded.push(token); start = end; } @@ -745,32 +745,6 @@ impl Model for BacktrackingBpe { let order_vocab_iter = OrderedVocabIter::new(&self.vocab_r); let serialized = serde_json::to_string(&order_vocab_iter)?; vocab_file.write_all(serialized.as_bytes())?; - // - // // Write merges.txt - // let merges_file_name = match name { - // Some(name) => format!("{name}-merges.txt"), - // None => "merges.txt".to_string(), - // }; - // - // let merges_path: PathBuf = [folder, Path::new(merges_file_name.as_str())] - // .iter() - // .collect(); - // let mut merges_file = File::create(&merges_path)?; - // let mut merges: Vec<(&Vec<&str, &str>, &u32)> = self - // .merges - // .iter() - // .map(|(pair, (rank, _))| (pair, rank)) - // .collect(); - // merges.sort_unstable_by_key(|k| *k.1); - // merges_file.write_all(b"#version: 0.2\n")?; - // merges_file.write_all( - // &merges - // .into_iter() - // .flat_map(|(pair, _)| { - // format!("{} {}\n", self.vocab_r[&pair.0], self.vocab_r[&pair.1]).into_bytes() - // }) - // .collect::>()[..], - // )?; Ok(vec![vocab_path]) // Ok(vec![vocab_path, merges_path]) } @@ -783,7 +757,6 @@ impl Model for BacktrackingBpe { #[cfg(test)] mod tests { use super::*; - use tempfile::NamedTempFile; #[test] fn my_example() { diff --git a/tokenizers/src/models/backtracking_bpe/serialization.rs b/tokenizers/src/models/backtracking_bpe/serialization.rs index 214d860d1..dabd5f748 100644 --- a/tokenizers/src/models/backtracking_bpe/serialization.rs +++ b/tokenizers/src/models/backtracking_bpe/serialization.rs @@ -14,7 +14,7 @@ impl Serialize for BacktrackingBpe { where S: Serializer, { - let mut model = serializer.serialize_struct("BacktrackingBpe", 8)?; + let mut model = serializer.serialize_struct("BPE", 8)?; // Start by small fields model.serialize_field("type", "BPE")?; @@ -45,7 +45,7 @@ impl<'de> Deserialize<'de> for BacktrackingBpe { D: Deserializer<'de>, { deserializer.deserialize_struct( - "BacktrackingBpe", + "BPE", &["type", "dropout", "unk_token", "vocab", "merges"], BacktrackingBpeVisitor, ) @@ -57,7 +57,7 @@ impl<'de> Visitor<'de> for BacktrackingBpeVisitor { type Value = BacktrackingBpe; fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(fmt, "struct BacktrackingBpe") + write!(fmt, "struct BacktrackingBpe to be the type") } fn visit_map(self, mut map: V) -> std::result::Result @@ -94,7 +94,7 @@ impl<'de> Visitor<'de> for BacktrackingBpeVisitor { u => { return Err(serde::de::Error::invalid_value( serde::de::Unexpected::Str(u), - &"BacktrackingBpe", + &"BacktrackingBpe should have been found", )) } }, @@ -124,7 +124,6 @@ mod test { #[test] fn test_serialization() { let vocab: Vocab = [ - ("".into(), 0), ("a".into(), 1), ("b".into(), 2), ("ab".into(), 3), @@ -138,17 +137,26 @@ mod test { .build() .unwrap(); - let legacy = r#"{"type":"BacktrackingBpe","dropout":null,"unk_token":"","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#; - let legacy = serde_json::from_str(legacy).unwrap(); - assert_eq!(bpe, legacy); + let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"","fuse_unk":false,"byte_fallback":false,"vocab":{"a":1,"b":2,"ab":3},"merges":["a b"]}"#; + let legacy = serde_json::from_str(legacy); + match legacy { + Ok(_) => { + println!("Good"); + assert_eq!(bpe, legacy.unwrap()); + } + Err(err) => { + println!("Error: {:?}", err); + } + } + let data = serde_json::to_string(&bpe).unwrap(); assert_eq!( data, - r#"{"type":"BacktrackingBpe","dropout":null,"unk_token":"","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"":0,"a":1,"b":2,"ab":3},"merges":[["a","b"]]}"# + r#"{"type":"BPE","vocab":{"ab":0,"a":1,"b":2},"merges":[["a","b"]]}"# ); let reconstructed = serde_json::from_str(&data).unwrap(); - assert_eq!(bpe, reconstructed); + assert_eq!(bpe, reconstructed); // TODO failing for now! // With a space in the token let vocab: Vocab = [ From 4b63a7af79cffbcbdeaf8a19d60dead2fc710ba3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 6 Jan 2025 11:14:49 +0100 Subject: [PATCH 09/11] update serialization to support initializing from BPE --- .../models/backtracking_bpe/serialization.rs | 50 +++++++++++++++++-- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/tokenizers/src/models/backtracking_bpe/serialization.rs b/tokenizers/src/models/backtracking_bpe/serialization.rs index dabd5f748..f5d015b17 100644 --- a/tokenizers/src/models/backtracking_bpe/serialization.rs +++ b/tokenizers/src/models/backtracking_bpe/serialization.rs @@ -2,6 +2,7 @@ use super::{ super::bpe::Pair, super::OrderedVocabIter, convert_merges_to_hashmap, BacktrackingBpe, BacktrackingBpeBuilder, }; +use regex_syntax::ast::print; use serde::{ de::{Error, MapAccess, Visitor}, ser::SerializeStruct, @@ -86,11 +87,11 @@ impl<'de> Visitor<'de> for BacktrackingBpeVisitor { builder = builder.unk_token(unk); } } - "vocab" => vocab = Some(map.next_value()?), + "vocab" => vocab = Some(map.next_value()?), "merges" => merges = Some(map.next_value()?), "type" => match map.next_value()? { "BacktrackingBpe" => {} - "BPE" => {} + "BPE" => {println!("Type is BPE but initializing a backtracking BPE")} u => { return Err(serde::de::Error::invalid_value( serde::de::Unexpected::Str(u), @@ -98,14 +99,18 @@ impl<'de> Visitor<'de> for BacktrackingBpeVisitor { )) } }, - _ => {} + field => { + println!("Ignoring unused field {:?}", field); // TODO make it into a logger + // Ensure the value is consumed to maintain valid deserialization + let _ = map.next_value::()?; + } } } if let (Some(vocab), Some(merges)) = (vocab, merges) { let merges = match merges { MergeType::Tuple(merges) => merges, MergeType::Legacy(merges) => { - convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(Error::custom)? + convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(|e| Error::custom("Error in convert merges to hashmap"))? } }; builder = builder.vocab_and_merges(vocab, merges); @@ -123,6 +128,29 @@ mod test { #[test] fn test_serialization() { + let bpe_string = r#"{ + "type": "BPE", + "dropout": null, + "unk_token": "", + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "byte_fallback": false, + "ignore_merges": true, + "vocab": { + "a": 1, + "b c d": 2, + "ab c d": 3 + }, + "merges": [ + ["a", "b c d"] + ] + }"#; + let reconstructed: Result = serde_json::from_str(&bpe_string); + println!("End of my example"); + + + let vocab: Vocab = [ ("a".into(), 1), ("b".into(), 2), @@ -137,6 +165,17 @@ mod test { .build() .unwrap(); + match reconstructed { + Ok(reconstructed) => { + println!("Good"); + assert_eq!(bpe, reconstructed); + } + Err(err) => { + println!("Error deserializing: {:?}", err); + + } + } + let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"","fuse_unk":false,"byte_fallback":false,"vocab":{"a":1,"b":2,"ab":3},"merges":["a b"]}"#; let legacy = serde_json::from_str(legacy); match legacy { @@ -180,5 +219,8 @@ mod test { ); let reconstructed = serde_json::from_str(&data).unwrap(); assert_eq!(bpe, reconstructed); + + + } } From a7baf1b3218c0d37de3c5728c1b40218a7a19754 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 6 Jan 2025 11:25:33 +0100 Subject: [PATCH 10/11] nits --- tokenizers/src/models/backtracking_bpe/serialization.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokenizers/src/models/backtracking_bpe/serialization.rs b/tokenizers/src/models/backtracking_bpe/serialization.rs index f5d015b17..87a068be9 100644 --- a/tokenizers/src/models/backtracking_bpe/serialization.rs +++ b/tokenizers/src/models/backtracking_bpe/serialization.rs @@ -114,7 +114,7 @@ impl<'de> Visitor<'de> for BacktrackingBpeVisitor { } }; builder = builder.vocab_and_merges(vocab, merges); - Ok(builder.build().map_err(Error::custom)?) + Ok(builder.build().map_err(|e| Error::custom(format!("Error building the backtraciing BPE {:?}", e)))?) } else { Err(Error::custom("Missing vocab/merges")) } From 224f432ef2f9c6da2d59bf3af6d47ef7749b5c88 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 21 Jan 2025 11:12:07 +0100 Subject: [PATCH 11/11] current state --- bindings/python/src/models.rs | 6 +- .../src/models/backtracking_bpe/model.rs | 82 +++++++++++-------- .../models/backtracking_bpe/serialization.rs | 6 +- 3 files changed, 57 insertions(+), 37 deletions(-) diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 863c69f18..a307f04cb 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -36,7 +36,7 @@ impl PyModel { let base = self.clone(); Ok(match *self.model.as_ref().read().unwrap() { ModelWrapper::BPE(_) => Py::new(py, (PyBPE {}, base))?.into_py(py), - ModelWrapper::BacktrackBpe(_) => Py::new(py, (PyBacktrackBPE {}, base))?.into_py(py), + ModelWrapper::BacktrackingBpe(_) => Py::new(py, (PyBacktrackingBpe {}, base))?.into_py(py), ModelWrapper::WordPiece(_) => Py::new(py, (PyWordPiece {}, base))?.into_py(py), ModelWrapper::WordLevel(_) => Py::new(py, (PyWordLevel {}, base))?.into_py(py), ModelWrapper::Unigram(_) => Py::new(py, (PyUnigram {}, base))?.into_py(py), @@ -562,10 +562,10 @@ impl PyBPE { } #[pyclass(module = "bpe")] -struct PyBacktrackBpe {} +struct PyBacktrackingBpe {} #[pymethods] -impl PyBacktrackBpe { +impl PyBacktrackingBpe { #[getter] fn get_dropout(self_: PyRef) -> Option { getter!(self_, BPE, dropout) diff --git a/tokenizers/src/models/backtracking_bpe/model.rs b/tokenizers/src/models/backtracking_bpe/model.rs index 25eef7e21..8b1ee122a 100644 --- a/tokenizers/src/models/backtracking_bpe/model.rs +++ b/tokenizers/src/models/backtracking_bpe/model.rs @@ -181,7 +181,7 @@ impl BacktrackingBpeBuilder { } let backtraching_bpe = BacktrackingBpe::from_dictionary( - self.config.vocab.into_iter().map(|(k, v)| k.into_bytes()), + self.config.vocab.into_iter().sorted_unstable_by(|a,b| a.1.cmp(&b.1)).map(|(k, v)| k.into_bytes()), Some(self.config.merges), None, ); @@ -225,6 +225,8 @@ fn is_valid_token_pair( } } // Reverse the merge operation from BPE. + println!("{token1}, {token2}"); + println!("{:?}", split_table); if token1 > token2 { limit = token1; token1 = unsafe { split_table.get_unchecked(token1 as usize).1 }; @@ -323,7 +325,8 @@ impl BacktrackingBpe { let mut start = 0; while start < bytes.len() { let end = bitfield.successor(start + 1); - let token = self.find_token_by_bytes(&bytes[start..end]).expect(&format!("Could not convert bytes to tokens for bytes: [{:?}]", bytes.into_iter().map(|b| char::from(*b)).join(","))); + println!("bitfield's successor {:?}", &bytes[start..end]); + let token = self.find_token_by_bytes(&bytes[start..end]).expect(&format!("Could not convert bytes to tokens for bytes: [{:?}]", bytes.into_iter().map(|b| char::from(*b)).join("|"))); encoded.push(token); start = end; } @@ -395,6 +398,7 @@ impl BacktrackingBpe { let mut bytes_hash_to_token = FnvHashMap::default(); let mut merge_map: HashMap = HashMap::new(); for (i, token) in tokens.into_iter().enumerate() { + println!("token byte: {:?}, {i}", token); bytes_hash_to_token.insert(hash_bytes(&token, hash_factor), i as u32); all_tokens_rev.extend(token.iter().copied().rev()); all_tokens.extend(token); @@ -421,7 +425,7 @@ impl BacktrackingBpe { .enumerate() .map(|(id, item)| { ( - unsafe { String::from_utf8_unchecked(Vec::from(item.clone())) }, + unsafe { String::from_utf8_unchecked(Vec::from(item)) }, id as u32, ) }) @@ -431,7 +435,7 @@ impl BacktrackingBpe { .enumerate() .map(|(id, item)| { (id as u32, unsafe { - String::from_utf8_unchecked(Vec::from(item.clone())) + String::from_utf8_unchecked(Vec::from(item)) }) }) .collect(); @@ -440,40 +444,53 @@ impl BacktrackingBpe { let mut pair_lookup = FnvHashMap::default(); if let Some(ref merges) = merges { - for (id, pair) in merges.into_iter().enumerate() { - let token1 = vocab[&pair.0.clone()]; - let token2 = vocab[&pair.1.clone()]; - pair_lookup.insert((token1, token2), id as u32); - split_table.push((token1, token2)); - merge_map.insert(Pair::from((token1, token2)), (id as u32, id as u32)); + for (index, pair) in merges.into_iter().enumerate() { + let token1 = &pair.0.clone(); + let token2 = &pair.1.clone(); + let id1 = vocab[token1]; + let id2 = vocab[token2]; + let new_token = format!("{}{}", token1, &token2); + let new_id = vocab + .get(&new_token) + .ok_or(Error::MergeTokenOutOfVocabulary(new_token)); + if let Ok(id) = new_id { + println!("{token1}, {token2}, {id1}, {id2}, {id}"); + pair_lookup.insert((id1, id2), *id); + split_table.push((id1, id2)); + merge_map.insert(Pair::from((id1, id2)), (index as u32, *id )); + }else{ + // gracefully error out + } + // TODO wrong } + split_table.push((merges.len() as u32, merges.len() as u32)); } else { // Reverse engineer the merge/split table. for (id, token) in token_iter(&all_tokens, &token_starts).enumerate() { - let mut token1 = next_prefix_match[id]; - while token1 != u32::MAX { - let rest = &token[token_range(&token_starts, token1).len()..]; - if let Some(token2) = find_token_by_bytes( + let mut id1 = next_prefix_match[id]; + while id1 != u32::MAX { + let rest = &token[token_range(&token_starts, id1).len()..]; + if let Some(id2) = find_token_by_bytes( &all_tokens, &token_starts, &bytes_hash_to_token, rest, hash_factor, ) { - if token1 < id as u32 - && token2 < id as u32 - && is_valid_token_pair(&pair_lookup, &split_table, token1, token2) + if id1 < id as u32 + && id2 < id as u32 + && is_valid_token_pair(&pair_lookup, &split_table, id1, id2) { - pair_lookup.insert((token1, token2), id as u32); - split_table.push((token1, token2)); - merge_map.insert(Pair::from((token1, token2)), (id as u32, id as u32)); + pair_lookup.insert((id1, id2), id as u32); + split_table.push((id1, id2)); + merge_map.insert(Pair::from((id1, id2)), (id as u32, id as u32)); break; } } - token1 = next_prefix_match[token1 as usize]; + id1 = next_prefix_match[id1 as usize]; } - if token1 == u32::MAX { + if id1 == u32::MAX { split_table.push((id as u32, id as u32)); } } @@ -495,16 +512,17 @@ impl BacktrackingBpe { vocab_r, merges: merge_map, }; - for token_id in 0..bpe.num_tokens() as u32 { - let bytes = bpe.token_bytes(token_id); - let strs = bytes.iter().map(|b| char::from(*b)).collect::>(); - let tokens = bpe.encode_via_bitfield(bytes); - assert_eq!( - tokens, - vec![token_id], - "token {token_id} with bytes {bytes:?} (tokens {strs:?} encodes to {tokens:?} instead of to itself" - ); - } + // for token_id in 0..bpe.num_tokens() as u32 { + // let bytes = bpe.token_bytes(token_id); + // let strs = bytes.iter().map(|b| char::from(*b)).collect::>(); + // println!("Encoding {bytes:?} into bitfield"); + // let tokens = bpe.encode_via_bitfield(bytes); + // assert_eq!( + // tokens, + // vec![token_id], + // "token {token_id} with bytes {bytes:?} (tokens {strs:?} encodes to {tokens:?} instead of to itself" + // ); + // } bpe } diff --git a/tokenizers/src/models/backtracking_bpe/serialization.rs b/tokenizers/src/models/backtracking_bpe/serialization.rs index 87a068be9..dce6fb5ea 100644 --- a/tokenizers/src/models/backtracking_bpe/serialization.rs +++ b/tokenizers/src/models/backtracking_bpe/serialization.rs @@ -114,7 +114,9 @@ impl<'de> Visitor<'de> for BacktrackingBpeVisitor { } }; builder = builder.vocab_and_merges(vocab, merges); - Ok(builder.build().map_err(|e| Error::custom(format!("Error building the backtraciing BPE {:?}", e)))?) + let model = builder.build().map_err(|e| Error::custom(format!("Error building the backtraciing BPE {:?}", e)))?; + println!("{:?}", model); + Ok(model) } else { Err(Error::custom("Missing vocab/merges")) } @@ -150,7 +152,6 @@ mod test { println!("End of my example"); - let vocab: Vocab = [ ("a".into(), 1), ("b".into(), 2), @@ -168,6 +169,7 @@ mod test { match reconstructed { Ok(reconstructed) => { println!("Good"); + println!("{:?}", reconstructed.encode_via_backtracking(b"aab c d")); assert_eq!(bpe, reconstructed); } Err(err) => {