diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 0d5c0ddcd..a307f04cb 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -36,6 +36,7 @@ impl PyModel { let base = self.clone(); Ok(match *self.model.as_ref().read().unwrap() { ModelWrapper::BPE(_) => Py::new(py, (PyBPE {}, base))?.into_py(py), + ModelWrapper::BacktrackingBpe(_) => Py::new(py, (PyBacktrackingBpe {}, base))?.into_py(py), ModelWrapper::WordPiece(_) => Py::new(py, (PyWordPiece {}, base))?.into_py(py), ModelWrapper::WordLevel(_) => Py::new(py, (PyWordLevel {}, base))?.into_py(py), ModelWrapper::Unigram(_) => Py::new(py, (PyUnigram {}, base))?.into_py(py), @@ -560,6 +561,145 @@ impl PyBPE { } } +#[pyclass(module = "bpe")] +struct PyBacktrackingBpe {} + +#[pymethods] +impl PyBacktrackingBpe { + #[getter] + fn get_dropout(self_: PyRef) -> Option { + getter!(self_, BPE, dropout) + } + + #[setter] + fn set_dropout(self_: PyRef, dropout: Option) { + setter!(self_, BPE, dropout, dropout); + } + + #[getter] + fn get_unk_token(self_: PyRef) -> Option { + getter!(self_, BPE, unk_token.clone()) + } + + #[setter] + fn set_unk_token(self_: PyRef, unk_token: Option) { + setter!(self_, BPE, unk_token, unk_token); + } + #[new] + #[pyo3( + signature = (vocab=None, merges=None, **kwargs), + text_signature = "(self, vocab=None, merges=None, dropout=None, unk_token=None)")] + fn new( + py: Python<'_>, + vocab: Option, + merges: Option, + kwargs: Option<&Bound<'_, PyDict>>, + ) -> PyResult<(Self, PyModel)> { + if (vocab.is_some() && merges.is_none()) || (vocab.is_none() && merges.is_some()) { + return Err(exceptions::PyValueError::new_err( + "`vocab` and `merges` must be both specified", + )); + } + + let mut builder = BPE::builder(); + if let (Some(vocab), Some(merges)) = (vocab, merges) { + match (vocab, merges) { + (PyVocab::Vocab(vocab), PyMerges::Merges(merges)) => { + builder = builder.vocab_and_merges(vocab, merges); + } + (PyVocab::Filename(vocab_filename), PyMerges::Filename(merges_filename)) => { + deprecation_warning( + py, + "0.9.0", + "BPE.__init__ will not create from files anymore, try `BPE.from_file` instead", + )?; + builder = + builder.files(vocab_filename.to_string(), merges_filename.to_string()); + } + _ => { + return Err(exceptions::PyValueError::new_err( + "`vocab` and `merges` must be both be from memory or both filenames", + )); + } + } + } + + PyBPE::with_builder(builder, kwargs) + } + + /// Read a :obj:`vocab.json` and a :obj:`merges.txt` files + /// + /// This method provides a way to read and parse the content of these files, + /// returning the relevant data structures. If you want to instantiate some BPE models + /// from memory, this method gives you the expected input from the standard files. + /// + /// Args: + /// vocab (:obj:`str`): + /// The path to a :obj:`vocab.json` file + /// + /// merges (:obj:`str`): + /// The path to a :obj:`merges.txt` file + /// + /// Returns: + /// A :obj:`Tuple` with the vocab and the merges: + /// The vocabulary and merges loaded into memory + #[staticmethod] + #[pyo3(text_signature = "(self, vocab, merges)")] + fn read_file(vocab: &str, merges: &str) -> PyResult<(Vocab, Merges)> { + BPE::read_file(vocab, merges).map_err(|e| { + exceptions::PyException::new_err(format!( + "Error while reading vocab & merges files: {}", + e + )) + }) + } + + /// Instantiate a BPE model from the given files. 
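+    /// The backtracking variant is loaded from the same :obj:`vocab.json` and
+    /// :obj:`merges.txt` files as the regular :class:`~tokenizers.models.BPE`.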
+ /// + /// This method is roughly equivalent to doing:: + /// + /// vocab, merges = BPE.read_file(vocab_filename, merges_filename) + /// bpe = BPE(vocab, merges) + /// + /// If you don't need to keep the :obj:`vocab, merges` values lying around, + /// this method is more optimized than manually calling + /// :meth:`~tokenizers.models.BPE.read_file` to initialize a :class:`~tokenizers.models.BPE` + /// + /// Args: + /// vocab (:obj:`str`): + /// The path to a :obj:`vocab.json` file + /// + /// merges (:obj:`str`): + /// The path to a :obj:`merges.txt` file + /// + /// Returns: + /// :class:`~tokenizers.models.BPE`: An instance of BPE loaded from these files + #[classmethod] + #[pyo3(signature = (vocab, merges, **kwargs))] + #[pyo3(text_signature = "(cls, vocab, merge, **kwargs)")] + fn from_file( + _cls: &Bound<'_, PyType>, + py: Python, + vocab: &str, + merges: &str, + kwargs: Option<&Bound<'_, PyDict>>, + ) -> PyResult> { + let (vocab, merges) = BPE::read_file(vocab, merges).map_err(|e| { + exceptions::PyException::new_err(format!("Error while reading BPE files: {}", e)) + })?; + Py::new( + py, + PyBPE::new( + py, + Some(PyVocab::Vocab(vocab)), + Some(PyMerges::Merges(merges)), + kwargs, + )?, + ) + } +} + + /// An implementation of the WordPiece algorithm /// /// Args: diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 16285e113..ddbeba4d1 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -68,6 +68,8 @@ fancy-regex = { version = "0.13", optional = true} getrandom = { version = "0.2.10" } esaxx-rs = { version = "0.1.10", default-features = false, features=[]} monostate = "0.1.12" +fnv = "1.0.7" +aneubeck-daachorse = "1.1.1" [features] default = ["progressbar", "onig", "esaxx_fast"] diff --git a/tokenizers/benches/llama3.rs b/tokenizers/benches/llama3.rs index 77af3bd63..f7104d757 100644 --- a/tokenizers/benches/llama3.rs +++ b/tokenizers/benches/llama3.rs @@ -2,26 +2,81 @@ extern crate criterion; use criterion::{Criterion, Throughput}; +use itertools::Itertools; +use tokenizers::models::backtracking_bpe; +use tokenizers::PreTokenizerWrapper; use tokenizers::Tokenizer; pub fn llama3(c: &mut Criterion) { let data = std::fs::read_to_string("data/big.txt").unwrap(); let mut group = c.benchmark_group("llama3-encode"); group.throughput(Throughput::Bytes(data.bytes().len() as u64)); - group.bench_function("llama3-offsets", |b| { - let tokenizer = - Tokenizer::from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct", None).unwrap(); + + group.bench_function("llama3-backtracking", |b| { + let mut tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap(); + let mut vocab = &mut tokenizer + .get_vocab(false) + .clone() + .into_iter() + .collect::>(); // Convert HashMap into a Vec of (String, u32) tuples + // + vocab.sort_by(|a, b| a.1.cmp(&b.1)); + vocab.truncate(vocab.len().saturating_sub(3)); + let vocab: Vec<_> = vocab // Sort by u32 value + .into_iter() // IntoIterator to get the iterator of Vec + .map(|(tok, _)| Vec::from(tok.as_bytes())) + .collect(); + let model: backtracking_bpe::BacktrackingBpe = + backtracking_bpe::BacktrackingBpe::from_dictionary(vocab, None, None); + tokenizer.with_model(model); + let data: Vec<_> = data.lines().collect(); + let add_special_tokens = false; + b.iter(|| { + tokenizer + .encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens) + .unwrap() + }) + }); + + group.bench_function("llama3-backtracking-no-pretok", |b| { + let mut tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap(); + let mut vocab = &mut 
tokenizer + .get_vocab(false) + .clone() + .into_iter() + .collect::>(); // Convert HashMap into a Vec of (String, u32) tuples + // + vocab.sort_by(|a, b| a.1.cmp(&b.1)); + vocab.truncate(vocab.len().saturating_sub(3)); + let vocab: Vec<_> = vocab // Sort by u32 value + .into_iter() // IntoIterator to get the iterator of Vec + .map(|(tok, _)| Vec::from(tok.as_bytes())) + .collect(); + let model: backtracking_bpe::BacktrackingBpe = + backtracking_bpe::BacktrackingBpe::from_dictionary(vocab, None, None); + tokenizer.with_model(model); + tokenizer.with_pre_tokenizer(None::); let data: Vec<_> = data.lines().collect(); let add_special_tokens = false; b.iter(|| { tokenizer - .encode_batch_char_offsets(criterion::black_box(data.clone()), add_special_tokens) + .encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens) .unwrap() }) }); - group.bench_function("llama3-nooffsets", |b| { - let tokenizer = - Tokenizer::from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct", None).unwrap(); + + group.bench_function("llama3-encode_batch_fast", |b| { + let tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap(); + let data: Vec<_> = data.lines().collect(); + let add_special_tokens = false; + b.iter(|| { + tokenizer + .encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens) + .unwrap() + }) + }); + group.bench_function("llama3-encode_batch", |b| { + let tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap(); let data: Vec<_> = data.lines().collect(); let add_special_tokens = false; b.iter(|| { @@ -30,13 +85,14 @@ pub fn llama3(c: &mut Criterion) { .unwrap() }) }); + group.finish(); } criterion_group! { - name = bert_benches; + name = llama; config = Criterion::default().sample_size(10); targets = llama3 } -criterion_main!(bert_benches); +criterion_main!(llama); diff --git a/tokenizers/src/models/backtracking_bpe/backtracking_state.rs b/tokenizers/src/models/backtracking_bpe/backtracking_state.rs new file mode 100644 index 000000000..98fe15381 --- /dev/null +++ b/tokenizers/src/models/backtracking_bpe/backtracking_state.rs @@ -0,0 +1,49 @@ +use super::bitfield::BitField; + +/// This can be thought of as a lazy variation of the dynamic programming approach. +/// It only computes those states which have to be visited in order to compute the tokenization +/// for a given input text. +/// It keeps track of visited states in a bitfield and only remembers the tokenization +/// of the currently processed dynamic programming state. +/// +/// The biggest downside of this approach is that the search for the longest leftmost match (the firt token?) +/// has to be reset at every (backtracking) step which is still a net win in practice compared to other approaches. +#[derive(Clone, PartialEq)] +pub struct BacktrackState<'a> { + pub(crate) text: &'a [u8], + pub(crate) tokens: Vec, // len of the tezt / 3 + pub(crate) next_token: Option, // bpe.next_match(text) wich is longest_searcher.leftmost_find_iter(text)'s first match value + pub(crate) pos: usize, // current pos in the text? + pub(crate) bitfield: BitField, // keeps track of token boundaries? keeps track of all the valid tokenization positions and making the runtime linear in the input length. 
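+    //
+    // Illustrative driving loop (a sketch of how this state is consumed; the model's
+    // `encode_via_backtracking` does exactly this, where `bpe` is a `BacktrackingBpe`):
+    //
+    //     let mut state = BacktrackState::new(text, bpe.next_match(text));
+    //     while bpe.step(&mut state).is_some() {}
+    //     let tokens = state.into_tokens();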
+} + +impl<'a> BacktrackState<'a> { + pub(crate) fn new(text: &'a [u8], next_token: Option) -> Self { + Self::with_capacity(text, next_token, text.len() / 3) + } + + pub(crate) fn with_capacity(text: &'a [u8], next_token: Option, cap: usize) -> Self { + Self { + text, + tokens: Vec::with_capacity(cap), + next_token, + pos: 0, + bitfield: BitField::new(text.len() + 1), + } + } + pub(crate) fn count(&self) -> usize { + self.tokens.len() + } + + pub(crate) fn pos(&self) -> usize { + self.pos + } + + pub(crate) fn last_token(&self) -> Option { + self.tokens.last().copied() + } + + pub(crate) fn into_tokens(self) -> Vec { + self.tokens + } +} diff --git a/tokenizers/src/models/backtracking_bpe/bitfield.rs b/tokenizers/src/models/backtracking_bpe/bitfield.rs new file mode 100644 index 000000000..832965931 --- /dev/null +++ b/tokenizers/src/models/backtracking_bpe/bitfield.rs @@ -0,0 +1,57 @@ +/// Small helper to manage a bit field which supports predecessor and successor queries with a simple scan implementation. +/// This is sufficient for our use case, since two one bits will be at most 128 bits apart. +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct BitField { + bitfield: Vec, +} + +impl BitField { + /// All bits are initialized to 1. + pub(crate) fn new(bits: usize) -> Self { + Self { + bitfield: vec![u64::MAX; (bits + 63) / 64], + } + } + + pub(crate) fn is_set(&self, bit: usize) -> bool { + let (word, bit) = (bit / 64, bit % 64); + self.bitfield[word] & (1 << bit) != 0 + } + + pub(crate) fn clear(&mut self, bit: usize) { + let (word, bit) = (bit / 64, bit % 64); + self.bitfield[word] &= !(1 << bit); + } + + pub(crate) fn successor(&self, bit: usize) -> usize { + let (mut word_idx, bit_idx) = (bit / 64, bit % 64); + let word = self.bitfield[word_idx] >> bit_idx; + if word != 0 { + word.trailing_zeros() as usize + bit + } else { + loop { + word_idx += 1; + let word = self.bitfield[word_idx]; + if word != 0 { + break word.trailing_zeros() as usize + word_idx * 64; + } + } + } + } + + pub(crate) fn predecessor(&self, bit: usize) -> usize { + let (mut word_idx, bit_idx) = (bit / 64, bit % 64); + let word = self.bitfield[word_idx] << (63 - bit_idx); + if word != 0 { + bit - word.leading_zeros() as usize + } else { + loop { + word_idx -= 1; + let word = self.bitfield[word_idx]; + if word != 0 { + break word_idx * 64 + 63 - word.leading_zeros() as usize; + } + } + } + } +} diff --git a/tokenizers/src/models/backtracking_bpe/mod.rs b/tokenizers/src/models/backtracking_bpe/mod.rs new file mode 100644 index 000000000..16f75787e --- /dev/null +++ b/tokenizers/src/models/backtracking_bpe/mod.rs @@ -0,0 +1,8 @@ +mod backtracking_state; +mod bitfield; +mod model; +mod serialization; +pub mod trainer; + +pub use model::*; +pub use trainer::*; diff --git a/tokenizers/src/models/backtracking_bpe/model.rs b/tokenizers/src/models/backtracking_bpe/model.rs new file mode 100644 index 000000000..8b1ee122a --- /dev/null +++ b/tokenizers/src/models/backtracking_bpe/model.rs @@ -0,0 +1,842 @@ +use super::bitfield::BitField; +use super::{super::bpe::trainer::BpeTrainer, super::bpe::Error, super::OrderedVocabIter}; +use crate::models::bpe::{MergeMap, Pair, BPE}; +use crate::tokenizer::{Model, Result, Token}; +use crate::utils::iter::ResultShunt; +use aneubeck_daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder}; +use fnv::{FnvHashMap, FnvHasher}; +use itertools::Itertools; +use serde::de::Visitor; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde_json::Value; +use 
std::cmp::Reverse; +use std::collections::BinaryHeap; +use std::hash::{Hash, Hasher}; +use std::ops::Range; +use std::{ + collections::HashMap, + fs::File, + io::prelude::*, + io::{BufRead, BufReader}, + path::{Path, PathBuf}, +}; +pub type Vocab = HashMap; +type VocabR = HashMap; +pub type Merges = Vec<(String, String)>; + +use super::backtracking_state::BacktrackState; + +struct Config { + files: Option<(String, String)>, + vocab: Vocab, + merges: Merges, + dropout: Option, + unk_token: Option, + fuse_unk: bool, + byte_fallback: bool, +} + +pub struct BacktrackingBpeBuilder { + config: Config, +} + +impl Default for BacktrackingBpeBuilder { + fn default() -> Self { + Self { + config: Config { + files: None, + vocab: HashMap::new(), + merges: vec![], + dropout: None, + unk_token: None, + fuse_unk: false, + byte_fallback: false, + }, + } + } +} + +/// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model. +#[derive(PartialEq, Clone)] +pub struct BacktrackingBpe { + /// All the decoded tokens concatenated into? used to build the aho corasick searchers + all_tokens: Vec, + /// Start index of each token in all_tokens. + /// The end is simply the next entry in this vector. + token_starts: Vec, + /// Mapping from hash of token to token id. + bytes_hash_to_token: FnvHashMap, + /// The two tokens from which the token got merged. + /// If the token is an original one, than the two tokens point back to itself. + split_table: Vec<(u32, u32)>, + /// Mapping from a pair of tokens to a merged token if such a merged token exists. + pair_lookup: FnvHashMap<(u32, u32), u32>, + /// An aho corasick automaton to find the next longest token in a byte sequence. + // #[serde( + // serialize_with = "serialize_daac", + // deserialize_with = "deserialize_daac" + // )] + longest_searcher: DoubleArrayAhoCorasick, + /// An aho corasick automaton to find ALL tokens in a byte sequence. + // #[serde( + // serialize_with = "serialize_daac", + // deserialize_with = "deserialize_daac" + // )] + pub(crate) overlapping_searcher: DoubleArrayAhoCorasick, + /// An aho corasick automaton to find ALL tokens in a byte sequence which is being processed in reverse order. + // #[serde( + // serialize_with = "serialize_daac", + // deserialize_with = "deserialize_daac" + // )] + pub(crate) overlapping_searcher_rev: DoubleArrayAhoCorasick, + /// Mapping from a token to the next longest prefix token. + /// This is in principle information represented by the AhoCorasick automaton. + /// But we don't have efficient access to it and therefore store it here again. + /// If there is none, then the value is set to u32::MAX. + next_prefix_match: Vec, + /// Hash factor used to prevent hash collisions. + hash_factor: u64, + pub vocab: Vocab, + pub vocab_r: VocabR, + unk_token: Option, + pub merges: MergeMap, +} + +use std::fmt; + +// Manually implement the Debug trait to exclude the `cache` field +impl fmt::Debug for BacktrackingBpe { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("BacktrackingBpe") + .field("vocab", &self.vocab) + .field("vocab_r", &self.vocab_r) + // Skipping `cache` field here, it won't be included in debug output + .finish() + } +} + +impl BacktrackingBpeBuilder { + /// Constructs a new `BacktrackingBpeBuilder`. + pub fn new() -> Self { + Self::default() + } + + /// Set the input files. + #[must_use] + pub fn files(mut self, vocab: String, merges: String) -> Self { + self.config.files = Some((vocab, merges)); + self + } + + /// Set the vocab (token -> ID) and merges mappings. 
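+    ///
+    /// For example (an illustrative sketch with a toy vocabulary):
+    ///
+    /// ```
+    /// use std::collections::HashMap;
+    /// use tokenizers::models::backtracking_bpe::BacktrackingBpe;
+    ///
+    /// let vocab = HashMap::from([
+    ///     ("a".to_string(), 0),
+    ///     ("b".to_string(), 1),
+    ///     ("ab".to_string(), 2),
+    /// ]);
+    /// let merges = vec![("a".to_string(), "b".to_string())];
+    /// let model = BacktrackingBpe::builder()
+    ///     .vocab_and_merges(vocab, merges)
+    ///     .build()
+    ///     .unwrap();
+    /// ```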
+ #[must_use] + pub fn vocab_and_merges(mut self, vocab: Vocab, merges: Merges) -> Self { + self.config.vocab = vocab; + self.config.merges = merges; + self + } + + /// Use [dropout](https://arxiv.org/abs/1910.13267) with the model. + #[must_use] + pub fn dropout(mut self, dropout: f32) -> Self { + self.config.dropout = Some(dropout); + self + } + + /// Set the `UNK` token for the vocab. + #[must_use] + pub fn unk_token(mut self, unk_token: String) -> Self { + self.config.unk_token = Some(unk_token); + self + } + + /// Set the `fuse_unk` option. + #[must_use] + pub fn fuse_unk(mut self, fuse_unk: bool) -> Self { + self.config.fuse_unk = fuse_unk; + self + } + + /// Set the `byte_fallback` option. + #[must_use] + pub fn byte_fallback(mut self, byte_fallback: bool) -> Self { + self.config.byte_fallback = byte_fallback; + self + } + + /// Returns a `BacktrackingBpe` model that uses the `BacktrackingBpeBuilder`'s configuration. + pub fn build(mut self) -> Result { + // Validate dropout. + if let Some(p) = self.config.dropout { + if !(0.0..=1.0).contains(&p) { + return Err(Error::InvalidDropout.into()); + } + } + + // Read files if necessary + if let Some((vocab, merges)) = self.config.files { + let (v, m) = BPE::read_file(&vocab, &merges)?; + self.config.vocab = v; + self.config.merges = m; + } + + let backtraching_bpe = BacktrackingBpe::from_dictionary( + self.config.vocab.into_iter().sorted_unstable_by(|a,b| a.1.cmp(&b.1)).map(|(k, v)| k.into_bytes()), + Some(self.config.merges), + None, + ); + Ok(backtraching_bpe) + } +} + +impl Default for BacktrackingBpe { + fn default() -> Self { + Self::builder().build().unwrap() + } +} + +fn token_iter<'a>(all_tokens: &'a [u8], token_starts: &'a [u32]) -> impl Iterator { + token_starts + .iter() + .tuple_windows() + .map(move |(start, end)| &all_tokens[*start as usize..*end as usize]) +} + +fn next_match(longest_searcher: &DoubleArrayAhoCorasick, text: &[u8]) -> Option { + longest_searcher + .leftmost_find_iter(text) + .map(|m| m.value()) + .next() +} + +fn is_valid_token_pair( + pair_lookup: &FnvHashMap<(u32, u32), u32>, + split_table: &[(u32, u32)], + mut token1: u32, + mut token2: u32, +) -> bool { + // Keep track of the maximum token which can still be chosen across the split point. + let mut limit = u32::MAX; + loop { + // Check whether BPE would choose a different token pair across the split point. + if let Some(combined) = pair_lookup.get(&(token1, token2)) { + if *combined < limit { + return false; + } + } + // Reverse the merge operation from BPE. 
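+        // Illustration of the check above (hypothetical ids): if a merged token
+        // "ab" = merge("a", "b") exists with an id below `limit`, real BPE would
+        // have merged across this split point, so the pair ("a", "b") can never
+        // appear side by side and we return `false`. Otherwise we decompose the
+        // token with the larger id via `split_table` and keep checking the
+        // smaller pieces on each side of the split point.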
+ println!("{token1}, {token2}"); + println!("{:?}", split_table); + if token1 > token2 { + limit = token1; + token1 = unsafe { split_table.get_unchecked(token1 as usize).1 }; + if token1 == limit { + limit = token2 + 1; + token2 = unsafe { split_table.get_unchecked(token2 as usize).0 }; + if token2 + 1 == limit { + return true; + } + } + } else { + limit = token2 + 1; + token2 = unsafe { split_table.get_unchecked(token2 as usize).0 }; + if token2 + 1 == limit { + limit = token1; + token1 = unsafe { split_table.get_unchecked(token1 as usize).1 }; + if token1 == limit { + return true; + } + } + } + } +} + +fn token_range(token_starts: &[u32], token_id: u32) -> Range { + unsafe { + *token_starts.get_unchecked(token_id as usize) as usize + ..*token_starts.get_unchecked(token_id as usize + 1) as usize + } +} + +fn token_bytes<'a>(all_tokens: &'a [u8], token_starts: &[u32], token_id: u32) -> &'a [u8] { + &all_tokens[token_range(token_starts, token_id)] +} + +fn hash_bytes(bytes: &[u8], factor: u64) -> u32 { + let mut hasher = FnvHasher::default(); + bytes.hash(&mut hasher); + // Note: we save 1/3 of space for the hashmap by only using the most significant bits of the hash. + // To make them unique for the given tokens, we have to add unfortunately another multiplication. + ((hasher.finish().wrapping_mul(factor)) >> 32) as u32 +} +fn find_token_by_bytes( + all_tokens: &[u8], + token_starts: &[u32], + bytes_hash_to_token: &FnvHashMap, + bytes: &[u8], + hash_factor: u64, +) -> Option { + let hash = hash_bytes(bytes, hash_factor); + let token = *bytes_hash_to_token.get(&hash)?; + if token_bytes(all_tokens, token_starts, token) == bytes { + Some(token) + } else { + None + } +} + +/// Converts the merges strings (for example from `merges.txt` file) with the format +/// "{pair_a} {pair_b}" into the format expected by the BacktrackingBpe struct +pub(crate) fn convert_merges_to_hashmap>( + iter: I, + _vocab: &Vocab, +) -> Result { + let mut merges = vec![]; + + let lines = iter.filter(|l| !l.starts_with("#version")); + for (rank, line) in lines.enumerate() { + let parts = line.split(' ').collect::>(); + if parts.len() != 2 { + return Err(Error::BadMerges(rank + 1).into()); + } + + merges.push((parts[0].to_string(), parts[1].to_string())); + } + + Ok(merges) +} + +impl BacktrackingBpe { + /// Initialize a `BacktrackingBpeBuilder`. + pub fn builder() -> BacktrackingBpeBuilder { + BacktrackingBpeBuilder::new() + } + + /// Create a new BacktrackingBpe model with the given vocab and merges. + pub fn new(vocab: Vocab, merges: Merges) -> Self { + Self::builder() + .vocab_and_merges(vocab, merges) + .build() + .unwrap() + } + + fn bitfield_into_tokens(&self, bytes: &[u8], bitfield: BitField, count: usize) -> Vec { + let mut encoded = Vec::with_capacity(count); + let mut start = 0; + while start < bytes.len() { + let end = bitfield.successor(start + 1); + println!("bitfield's successor {:?}", &bytes[start..end]); + let token = self.find_token_by_bytes(&bytes[start..end]).expect(&format!("Could not convert bytes to tokens for bytes: [{:?}]", bytes.into_iter().map(|b| char::from(*b)).join("|"))); + encoded.push(token); + start = end; + } + encoded + } + + fn encode_into_bitfield(&self, bytes: &[u8]) -> (BitField, usize) { + // Reserve for every byte a bit in the bitfield. 
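+        // Sketch of the approach: every set bit marks a token boundary, so the
+        // initial state is one single-byte token per input byte. The heap yields
+        // candidate merges smallest-token-id first (i.e. in merge-rank order,
+        // assuming ids are ordered by rank); applying one clears the boundary bit
+        // in the middle and may enable longer candidates spanning the neighbours.
+        // E.g. for b"abc" with tokens {"a", "b", "c", "ab"}: start from |a|b|c|,
+        // pop the "ab" candidate at position 0, clear the bit between "a" and "b",
+        // and end up with |ab|c|.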
+ let mut bitfield = BitField::new(bytes.len() + 1); + let mut heap = BinaryHeap::with_capacity(bytes.len() * 2); + heap.extend((0..bytes.len().saturating_sub(1)).filter_map(|i| { + self.find_token_by_bytes(&bytes[i..i + 2]) + .map(|e| Reverse((e, i as u32))) + })); + let mut count = bytes.len(); + while let Some(Reverse((token, start))) = heap.pop() { + let start = start as usize; + if !bitfield.is_set(start) { + continue; + } + let mid = bitfield.successor(start + 1); + if mid >= bytes.len() { + continue; + } + let end = bitfield.successor(mid + 1); + if self.token_len(token) != end - start { + continue; + } + bitfield.clear(mid); + count -= 1; + if end < bytes.len() { + let new_end = bitfield.successor(end + 1); + if let Some(e) = self.find_token_by_bytes(&bytes[start..new_end]) { + heap.push(Reverse((e, start as u32))); + } + } + if start > 0 { + let new_start = bitfield.predecessor(start - 1); + if let Some(e) = self.find_token_by_bytes(&bytes[new_start..end]) { + heap.push(Reverse((e, new_start as u32))); + } + } + } + (bitfield, count) + } + + pub fn encode_via_bitfield(&self, text: &[u8]) -> Vec { + let (bitfield, count) = self.encode_into_bitfield(text); + self.bitfield_into_tokens(text, bitfield, count) + } + + /// Construct a BytePairEncoding instance from an iterator that enumerates all tokens. + /// A suitable hash factor may be necessary to prevent hash collisions, which can be + /// found using [`find_hash_factor_for_dictionary`]. + /// + /// The recommended approach is to store the serialized value and reuse that, + /// to prevent repeating the cost of computing the hash factor and encoding. + pub fn from_dictionary( + tokens: impl IntoIterator>, + merges: Option, + hash_factor: Option, + ) -> Self { + let hash_factor = hash_factor + .inspect(|f| assert_ne!(*f, 0, "hash factor must be larger than zero")) + .unwrap_or(1); + let mut all_tokens = Vec::new(); + let mut all_tokens_rev = Vec::new(); + let mut token_starts = vec![0]; + let mut bytes_hash_to_token = FnvHashMap::default(); + let mut merge_map: HashMap = HashMap::new(); + for (i, token) in tokens.into_iter().enumerate() { + println!("token byte: {:?}, {i}", token); + bytes_hash_to_token.insert(hash_bytes(&token, hash_factor), i as u32); + all_tokens_rev.extend(token.iter().copied().rev()); + all_tokens.extend(token); + token_starts.push(all_tokens.len() as u32); + } + assert_eq!(bytes_hash_to_token.len() + 1, token_starts.len()); + let longest_searcher = DoubleArrayAhoCorasickBuilder::new() + .match_kind(aneubeck_daachorse::MatchKind::LeftmostLongest) + .build(token_iter(&all_tokens, &token_starts)) + .expect("failed to build AhoCorasick"); + + let overlapping_searcher = + DoubleArrayAhoCorasick::::new(token_iter(&all_tokens, &token_starts)).expect(""); + let overlapping_searcher_rev = + DoubleArrayAhoCorasick::::new(token_iter(&all_tokens, &token_starts)).expect(""); + + let next_prefix_match: Vec<_> = token_iter(&all_tokens, &token_starts) + .map(|token| { + next_match(&longest_searcher, &token[0..token.len() - 1]).unwrap_or(u32::MAX) + }) + .collect(); + + let vocab: HashMap = token_iter(&all_tokens, &token_starts) + .enumerate() + .map(|(id, item)| { + ( + unsafe { String::from_utf8_unchecked(Vec::from(item)) }, + id as u32, + ) + }) + .collect(); + + let vocab_r: HashMap = token_iter(&all_tokens, &token_starts) + .enumerate() + .map(|(id, item)| { + (id as u32, unsafe { + String::from_utf8_unchecked(Vec::from(item)) + }) + }) + .collect(); + + let mut split_table = vec![]; + let mut pair_lookup = 
FnvHashMap::default(); + + if let Some(ref merges) = merges { + for (index, pair) in merges.into_iter().enumerate() { + let token1 = &pair.0.clone(); + let token2 = &pair.1.clone(); + let id1 = vocab[token1]; + let id2 = vocab[token2]; + let new_token = format!("{}{}", token1, &token2); + let new_id = vocab + .get(&new_token) + .ok_or(Error::MergeTokenOutOfVocabulary(new_token)); + if let Ok(id) = new_id { + println!("{token1}, {token2}, {id1}, {id2}, {id}"); + pair_lookup.insert((id1, id2), *id); + split_table.push((id1, id2)); + merge_map.insert(Pair::from((id1, id2)), (index as u32, *id )); + }else{ + // gracefully error out + } + + // TODO wrong + } + split_table.push((merges.len() as u32, merges.len() as u32)); + } else { + // Reverse engineer the merge/split table. + for (id, token) in token_iter(&all_tokens, &token_starts).enumerate() { + let mut id1 = next_prefix_match[id]; + while id1 != u32::MAX { + let rest = &token[token_range(&token_starts, id1).len()..]; + if let Some(id2) = find_token_by_bytes( + &all_tokens, + &token_starts, + &bytes_hash_to_token, + rest, + hash_factor, + ) { + if id1 < id as u32 + && id2 < id as u32 + && is_valid_token_pair(&pair_lookup, &split_table, id1, id2) + { + pair_lookup.insert((id1, id2), id as u32); + split_table.push((id1, id2)); + merge_map.insert(Pair::from((id1, id2)), (id as u32, id as u32)); + break; + } + } + id1 = next_prefix_match[id1 as usize]; + } + if id1 == u32::MAX { + split_table.push((id as u32, id as u32)); + } + } + }; + + let bpe = Self { + all_tokens, + token_starts, + bytes_hash_to_token, + overlapping_searcher, + overlapping_searcher_rev, + longest_searcher, + next_prefix_match, + pair_lookup, + split_table, + hash_factor, + unk_token: None, + vocab, + vocab_r, + merges: merge_map, + }; + // for token_id in 0..bpe.num_tokens() as u32 { + // let bytes = bpe.token_bytes(token_id); + // let strs = bytes.iter().map(|b| char::from(*b)).collect::>(); + // println!("Encoding {bytes:?} into bitfield"); + // let tokens = bpe.encode_via_bitfield(bytes); + // assert_eq!( + // tokens, + // vec![token_id], + // "token {token_id} with bytes {bytes:?} (tokens {strs:?} encodes to {tokens:?} instead of to itself" + // ); + // } + bpe + } + + /// Initialize a BacktrackingBpeBuilder model from vocab and merges files + pub fn from_file(vocab: &str, merges: &str) -> BacktrackingBpeBuilder { + Self::builder().files(vocab.to_owned(), merges.to_owned()) + } + + /// Read the given files to extract the vocab and merges + pub fn read_file(vocab: &str, merges: &str) -> Result<(Vocab, Merges)> { + // Read vocab.json + let vocab_file = File::open(vocab)?; + let mut vocab_file = BufReader::new(vocab_file); + + let mut buffer = String::new(); + vocab_file.read_to_string(&mut buffer)?; + let json: Value = serde_json::from_str(&buffer)?; + let mut vocab = HashMap::new(); + match json { + Value::Object(m) => { + for (token, id) in m { + if let Value::Number(id) = id { + let id = id.as_u64().ok_or(Error::BadVocabulary)? as u32; + vocab.insert(token, id); + } + } + } + _ => return Err(Box::new(Error::BadVocabulary)), + }; + + // Read merges file + let merge_file = File::open(merges)?; + let merge_file = BufReader::new(merge_file); + let merges = ResultShunt::process(merge_file.lines(), |iter| { + convert_merges_to_hashmap(iter, &vocab) + })??; // TODO correctly process to fill the split and pair lookup + + Ok((vocab, merges)) + } + + /// Return the number of tokens in this BPE dictionary. 
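+    /// This is simply `token_starts.len() - 1`, since `token_starts` stores one
+    /// extra entry marking the end of the last token.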
+ pub fn num_tokens(&self) -> usize { + self.token_starts.len() - 1 + } + + /// Converts a token id into its corresponding token bytes. + /// Panics if the token_id is not within the valid 0..num_tokens() range! + pub fn token_bytes(&self, token_id: u32) -> &[u8] { + token_bytes(&self.all_tokens, &self.token_starts, token_id) + } + + pub(crate) fn is_valid_token_pair(&self, token1: u32, token2: u32) -> bool { + is_valid_token_pair(&self.pair_lookup, &self.split_table, token1, token2) + } + + /// Returns the length of the decoded byte slice of a token. + pub fn token_len(&self, token_id: u32) -> usize { + token_range(&self.token_starts, token_id).len() + } + + /// Returns the first longest match in the provided text. + pub(crate) fn next_match(&self, text: &[u8]) -> Option { + next_match(&self.longest_searcher, text) + } + + /// Returns the next token which shares the longest prefix with the specified token. + pub(crate) fn next_prefix(&self, token_id: u32) -> Option { + let prefix = self.next_prefix_match[token_id as usize]; + if prefix == u32::MAX { + None + } else { + Some(prefix) + } + } + + fn find_token_by_bytes(&self, bytes: &[u8]) -> Option { + find_token_by_bytes( + &self.all_tokens, + &self.token_starts, + &self.bytes_hash_to_token, + bytes, + self.hash_factor, + ) + } + + /// Decode a sequence of tokens back to its original byte sequence. + /// Note: we don't return here a str, since not every token sequence corresponds to a valid + /// utf8 sequence. + pub fn decode_tokens(&self, tokens: &[u32]) -> Vec { + let mut text = vec![]; + for token in tokens { + text.extend(self.token_bytes(*token)); + } + text + } + + /// Computes for every prefix of the input text a corresponding last token. + pub(crate) fn encode_all_prefixes(&self, text: &[u8]) -> Vec { + let mut last_token = Vec::with_capacity(text.len()); + let mut state = self.overlapping_searcher.start_state(); + for (pos, c) in text.iter().enumerate() { + let (s, iter) = self.overlapping_searcher.consume(state, pos + 1, *c); + state = s; + for m in iter { + let new_token = m.value(); + let new_range = m.start()..m.end(); + assert_eq!(new_range.end, last_token.len() + 1); + if new_range.start == 0 { + last_token.push(new_token); + break; + } else { + let prev_token = unsafe { *last_token.get_unchecked(new_range.start - 1) }; + if self.is_valid_token_pair(prev_token, new_token) { + last_token.push(new_token); + break; + } + } + } + } + last_token + } + + /// Counts the number tokens produced when encoding the text. 
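+    ///
+    /// Note that the state is seeded with `next_token = None` here, whereas
+    /// `encode_via_backtracking` seeds it with `self.next_match(text)`; since
+    /// `step` returns early when no next token is set, this difference may be
+    /// worth double-checking.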
+ pub fn count(&mut self, text: &[u8]) -> usize { + let mut enc = BacktrackState::new(text, None); + while self.step(&mut enc).is_some() {} + enc.count() + } + + pub fn encode_via_table(&self, text: &[u8]) -> Vec { + let last_token = self.encode_all_prefixes(text); + let mut encoded = Vec::with_capacity(text.len() / 3); + let mut pos = text.len(); + while pos > 0 { + let token = last_token[pos - 1]; + encoded.push(token); + pos -= self.token_len(token); + } + encoded.reverse(); + encoded + } + + pub fn encode_via_backtracking(&self, text: &[u8]) -> Vec { + let next_token = self.next_match(text); + let mut enc = BacktrackState::new(text, next_token); + while self.step(&mut enc).is_some() {} + enc.into_tokens() + } + + pub fn get_vocab(&self) -> Vocab { + self.vocab.clone() + } + + pub fn get_unk_token(&self) -> &Option { + &self.unk_token + } + + pub fn step(&self, backtrack_state: &mut BacktrackState) -> Option { + let mut token = backtrack_state.next_token?; + let last = backtrack_state.tokens.last().copied(); + loop { + let token_len = self.token_len(token); + let end_pos = backtrack_state.pos + token_len; + if backtrack_state.bitfield.is_set(end_pos) + && last + .map(|last_token| self.is_valid_token_pair(last_token, token)) + .unwrap_or(true) + { + backtrack_state.tokens.push(token); + backtrack_state.pos = end_pos; + // In principle, we could in some cases reuse the leftmost longest match iterator. + // Especially when it has to look ahead, this could save scanning the input multiple times. + // But on average this seems to be slower due to the overhead of storing the iterator as part of the struct. + backtrack_state.next_token = self.next_match(&backtrack_state.text[end_pos..]); + break; + } else if let Some(shorter) = self.next_prefix(token) { + token = shorter; + } else { + // Clearing the bitfield when we pop tokens saves a little bit of work... + backtrack_state.bitfield.clear(backtrack_state.pos); + backtrack_state.tokens.pop(); + backtrack_state.pos -= last.map(|t| self.token_len(t)).unwrap_or(0); + backtrack_state.next_token = last; + break; + } + } + backtrack_state.next_token + } + + fn word_to_tokens<'a, 'b: 'a>( + &'a self, + word: &'b Vec, + ) -> impl Iterator + 'a { + word.into_iter() + .map(move |id| Token::new(*id, self.vocab_r[&id].clone(), (0usize, 0usize))) + // TODO offsets should be easy to integrate as well! 
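+        // One possible shape for the offsets TODO (a sketch, not wired in): keep a
+        // running byte offset and derive each span from `token_len`, e.g.
+        //
+        //     let mut offset = 0;
+        //     word.iter().map(move |id| {
+        //         let len = self.token_len(*id);
+        //         let token = Token::new(*id, self.vocab_r[id].clone(), (offset, offset + len));
+        //         offset += len;
+        //         token
+        //     })
+        //
+        // (byte offsets relative to the start of the tokenized sequence)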
+ } +} +impl Model for BacktrackingBpe { + type Trainer = BpeTrainer; + + fn get_vocab(&self) -> HashMap { + self.vocab.clone() + } + + fn get_vocab_size(&self) -> usize { + self.vocab.len() + } + + fn tokenize(&self, sequence: &str) -> Result> { + if sequence.is_empty() { + return Ok(vec![]); + } + let byte_text = sequence.as_bytes(); + let word = self.encode_via_backtracking(byte_text); + Ok(self.word_to_tokens(&word).collect()) + } + + fn token_to_id(&self, token: &str) -> Option { + self.vocab.get(token).copied() + } + + fn id_to_token(&self, id: u32) -> Option { + Some(self.vocab_r[&id].clone()) + } + + fn save(&self, folder: &Path, name: Option<&str>) -> Result> { + let vocab_file_name = match name { + Some(name) => format!("{name}-vocab.json"), + None => "vocab.json".to_string(), + }; + + // Write vocab.json + let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())] + .iter() + .collect(); + let mut vocab_file = File::create(&vocab_path)?; + let order_vocab_iter = OrderedVocabIter::new(&self.vocab_r); + let serialized = serde_json::to_string(&order_vocab_iter)?; + vocab_file.write_all(serialized.as_bytes())?; + Ok(vec![vocab_path]) + // Ok(vec![vocab_path, merges_path]) + } + + fn get_trainer(&self) -> BpeTrainer { + BpeTrainer::default() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn my_example() { + let tokens = [ + "a", "b", "c", // 1 character each + "aac", "ac", "cc", "cca", "aacc", "aaccca", "acca", "acc", "aa", "aaa", + "aaaa", // 2 characters each + ]; + let mut bpe = + BacktrackingBpe::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None, None); + // bpe.encode_via_backtracking(b"baacca"); + let tokens = bpe.tokenize("aaaacc").unwrap(); + println!("{:?}", bpe.tokenize("aaaacc")); + assert_eq!( + tokens, + vec![ + Token { + id: 12, + value: String::from("aaa"), + offsets: (0, 0) + }, + Token { + id: 10, + value: String::from("acc"), + offsets: (0, 0) + } + ] + ); + println!("{:?}", bpe.tokenize("baaaaccca")); + let tokens = bpe.tokenize("baaaaccca").unwrap(); + assert_eq!( + tokens, + vec![ + Token { + id: 1, + value: String::from("b"), + offsets: (0, 0) + }, + Token { + id: 12, + value: String::from("aaa"), + offsets: (0, 0) + }, + Token { + id: 4, + value: String::from("ac"), + offsets: (0, 0) + }, + Token { + id: 6, + value: String::from("cca"), + offsets: (0, 0) + } + ] + ); + bpe.encode_via_backtracking(b"baaaaccca"); + let tokens = [ + "a", "b", "c", // 1 character each + "acca", "cc", "ac", "aac", "cca", + ]; + let mut bpe = + BacktrackingBpe::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None, None); + bpe.encode_via_backtracking(b"baacca"); + } +} diff --git a/tokenizers/src/models/backtracking_bpe/serialization.rs b/tokenizers/src/models/backtracking_bpe/serialization.rs new file mode 100644 index 000000000..dce6fb5ea --- /dev/null +++ b/tokenizers/src/models/backtracking_bpe/serialization.rs @@ -0,0 +1,228 @@ +use super::{ + super::bpe::Pair, super::OrderedVocabIter, convert_merges_to_hashmap, BacktrackingBpe, + BacktrackingBpeBuilder, +}; +use regex_syntax::ast::print; +use serde::{ + de::{Error, MapAccess, Visitor}, + ser::SerializeStruct, + Deserialize, Deserializer, Serialize, Serializer, +}; +use std::collections::HashMap; + +impl Serialize for BacktrackingBpe { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut model = serializer.serialize_struct("BPE", 8)?; + + // Start by small fields + model.serialize_field("type", "BPE")?; + + // Then the large ones + let mut merges: 
Vec<(&Pair, &u32)> = self + .merges + .iter() + .map(|(pair, (rank, _))| (pair, rank)) + .collect(); + merges.sort_unstable_by_key(|k| *k.1); + let merges = merges + .into_iter() + .map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone())) + .collect::>(); + let ordered_vocab = OrderedVocabIter::new(&self.vocab_r); + + model.serialize_field("vocab", &ordered_vocab)?; + model.serialize_field("merges", &merges)?; + + model.end() + } +} + +impl<'de> Deserialize<'de> for BacktrackingBpe { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_struct( + "BPE", + &["type", "dropout", "unk_token", "vocab", "merges"], + BacktrackingBpeVisitor, + ) + } +} + +struct BacktrackingBpeVisitor; +impl<'de> Visitor<'de> for BacktrackingBpeVisitor { + type Value = BacktrackingBpe; + + fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(fmt, "struct BacktrackingBpe to be the type") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: MapAccess<'de>, + { + let mut builder = BacktrackingBpeBuilder::new(); + let mut vocab: Option> = None; + + #[derive(Debug, Deserialize)] + #[serde(untagged)] + enum MergeType { + Tuple(Vec<(String, String)>), + Legacy(Vec), + } + let mut merges: Option = None; + while let Some(key) = map.next_key::()? { + match key.as_ref() { + "dropout" => { + if let Some(dropout) = map.next_value()? { + builder = builder.dropout(dropout); + } + } + "unk_token" => { + if let Some(unk) = map.next_value()? { + builder = builder.unk_token(unk); + } + } + "vocab" => vocab = Some(map.next_value()?), + "merges" => merges = Some(map.next_value()?), + "type" => match map.next_value()? { + "BacktrackingBpe" => {} + "BPE" => {println!("Type is BPE but initializing a backtracking BPE")} + u => { + return Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Str(u), + &"BacktrackingBpe should have been found", + )) + } + }, + field => { + println!("Ignoring unused field {:?}", field); // TODO make it into a logger + // Ensure the value is consumed to maintain valid deserialization + let _ = map.next_value::()?; + } + } + } + if let (Some(vocab), Some(merges)) = (vocab, merges) { + let merges = match merges { + MergeType::Tuple(merges) => merges, + MergeType::Legacy(merges) => { + convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(|e| Error::custom("Error in convert merges to hashmap"))? 
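+                    // Both merge encodings are accepted here, e.g. (illustrative JSON):
+                    //   legacy:  "merges": ["a b", "c d"]
+                    //   tuples:  "merges": [["a", "b"], ["c", "d"]]
+                    // The legacy form is split on the single space between the two parts.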
+ } + }; + builder = builder.vocab_and_merges(vocab, merges); + let model = builder.build().map_err(|e| Error::custom(format!("Error building the backtraciing BPE {:?}", e)))?; + println!("{:?}", model); + Ok(model) + } else { + Err(Error::custom("Missing vocab/merges")) + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::models::bpe::Vocab; + + #[test] + fn test_serialization() { + let bpe_string = r#"{ + "type": "BPE", + "dropout": null, + "unk_token": "", + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "byte_fallback": false, + "ignore_merges": true, + "vocab": { + "a": 1, + "b c d": 2, + "ab c d": 3 + }, + "merges": [ + ["a", "b c d"] + ] + }"#; + let reconstructed: Result = serde_json::from_str(&bpe_string); + println!("End of my example"); + + + let vocab: Vocab = [ + ("a".into(), 1), + ("b".into(), 2), + ("ab".into(), 3), + ] + .iter() + .cloned() + .collect(); + let bpe = BacktrackingBpeBuilder::default() + .vocab_and_merges(vocab, vec![("a".to_string(), "b".to_string())]) + .unk_token("".to_string()) + .build() + .unwrap(); + + match reconstructed { + Ok(reconstructed) => { + println!("Good"); + println!("{:?}", reconstructed.encode_via_backtracking(b"aab c d")); + assert_eq!(bpe, reconstructed); + } + Err(err) => { + println!("Error deserializing: {:?}", err); + + } + } + + let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"","fuse_unk":false,"byte_fallback":false,"vocab":{"a":1,"b":2,"ab":3},"merges":["a b"]}"#; + let legacy = serde_json::from_str(legacy); + match legacy { + Ok(_) => { + println!("Good"); + assert_eq!(bpe, legacy.unwrap()); + } + Err(err) => { + println!("Error: {:?}", err); + } + } + + + let data = serde_json::to_string(&bpe).unwrap(); + assert_eq!( + data, + r#"{"type":"BPE","vocab":{"ab":0,"a":1,"b":2},"merges":[["a","b"]]}"# + ); + let reconstructed = serde_json::from_str(&data).unwrap(); + assert_eq!(bpe, reconstructed); // TODO failing for now! 
+ + // With a space in the token + let vocab: Vocab = [ + ("".into(), 0), + ("a".into(), 1), + ("b c d".into(), 2), + ("ab c d".into(), 3), + ] + .iter() + .cloned() + .collect(); + let bpe = BacktrackingBpeBuilder::default() + .vocab_and_merges(vocab, vec![("a".to_string(), "b c d".to_string())]) + .unk_token("".to_string()) + .build() + .unwrap(); + let data = serde_json::to_string(&bpe).unwrap(); + assert_eq!( + data, + r#"{"type":"BacktrackingBpe","dropout":null,"unk_token":"","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"":0,"a":1,"b c d":2,"ab c d":3},"merges":[["a","b c d"]]}"# + ); + let reconstructed = serde_json::from_str(&data).unwrap(); + assert_eq!(bpe, reconstructed); + + + + } +} diff --git a/tokenizers/src/models/backtracking_bpe/trainer.rs b/tokenizers/src/models/backtracking_bpe/trainer.rs new file mode 100644 index 000000000..94270c374 --- /dev/null +++ b/tokenizers/src/models/backtracking_bpe/trainer.rs @@ -0,0 +1,856 @@ +#![allow(clippy::map_entry)] + +use super::{ + super::bpe::Pair, super::bpe::WithFirstLastIterator, super::bpe::Word, BacktrackingBpe, +}; +use crate::parallelism::*; +use crate::tokenizer::{AddedToken, Result, Trainer}; +use crate::utils::progress::{ProgressBar, ProgressStyle}; +use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; +use std::collections::{BinaryHeap, HashMap, HashSet}; + +#[derive(Debug, Eq)] +struct Merge { + pair: Pair, + count: u64, + pos: HashSet, +} +impl PartialEq for Merge { + fn eq(&self, other: &Self) -> bool { + self.count == other.count && self.pair == other.pair + } +} +impl PartialOrd for Merge { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for Merge { + fn cmp(&self, other: &Self) -> Ordering { + if self.count != other.count { + self.count.cmp(&other.count) + } else { + // Here we want ascending order + other.pair.cmp(&self.pair) + } + } +} + +struct Config { + min_frequency: u64, + vocab_size: usize, + show_progress: bool, + special_tokens: Vec, + limit_alphabet: Option, + initial_alphabet: HashSet, + continuing_subword_prefix: Option, + end_of_word_suffix: Option, + max_token_length: Option, +} + +/// A `BpeTrainerBuilder` can be used to create a `BpeTrainer` with a custom +/// configuration. 
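+///
+/// For example (an illustrative sketch):
+///
+/// ```
+/// use tokenizers::models::backtracking_bpe::BpeTrainerBuilder;
+///
+/// let trainer = BpeTrainerBuilder::new()
+///     .vocab_size(1000)
+///     .min_frequency(2)
+///     .show_progress(false)
+///     .build();
+/// ```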
+pub struct BpeTrainerBuilder { + config: Config, +} + +impl Default for BpeTrainerBuilder { + fn default() -> Self { + Self { + config: Config { + min_frequency: 0, + vocab_size: 30000, + show_progress: true, + special_tokens: vec![], + limit_alphabet: None, + initial_alphabet: HashSet::new(), + continuing_subword_prefix: None, + end_of_word_suffix: None, + max_token_length: None, + }, + } + } +} + +impl BpeTrainerBuilder { + /// Constructs a new `BpeTrainerBuilder` + pub fn new() -> Self { + Self::default() + } + + /// Set the expected minimum frequency + #[must_use] + pub fn min_frequency(mut self, frequency: u64) -> Self { + self.config.min_frequency = frequency; + self + } + + /// Set the vocabulary size + #[must_use] + pub fn vocab_size(mut self, size: usize) -> Self { + self.config.vocab_size = size; + self + } + + /// Set whether to show progress + #[must_use] + pub fn show_progress(mut self, show: bool) -> Self { + self.config.show_progress = show; + self + } + + /// Set the special tokens + #[must_use] + pub fn special_tokens(mut self, tokens: Vec) -> Self { + self.config.special_tokens = tokens; + self + } + + /// Set whether to limit the alphabet + #[must_use] + pub fn limit_alphabet(mut self, limit: usize) -> Self { + self.config.limit_alphabet = Some(limit); + self + } + + /// Set the initial alphabet + #[must_use] + pub fn initial_alphabet(mut self, alphabet: HashSet) -> Self { + self.config.initial_alphabet = alphabet; + self + } + + /// Set the continuing_subword_prefix + #[must_use] + pub fn continuing_subword_prefix(mut self, prefix: String) -> Self { + self.config.continuing_subword_prefix = Some(prefix); + self + } + + /// Set the end_of_word_suffix + #[must_use] + pub fn end_of_word_suffix(mut self, suffix: String) -> Self { + self.config.end_of_word_suffix = Some(suffix); + self + } + /// Set max_token_length + #[must_use] + pub fn max_token_length(mut self, max_token_length: Option) -> Self { + self.config.max_token_length = max_token_length; + self + } + + /// Constructs the final BpeTrainer + pub fn build(self) -> BacktrackingBpeTrainer { + BacktrackingBpeTrainer { + min_frequency: self.config.min_frequency, + vocab_size: self.config.vocab_size, + show_progress: self.config.show_progress, + special_tokens: self.config.special_tokens, + limit_alphabet: self.config.limit_alphabet, + initial_alphabet: self.config.initial_alphabet, + continuing_subword_prefix: self.config.continuing_subword_prefix, + end_of_word_suffix: self.config.end_of_word_suffix, + max_token_length: self.config.max_token_length, + words: HashMap::new(), + } + } +} + +/// In charge of training a `BacktrackingBpe` model +/// +/// # Examples +/// +/// ``` +/// use tokenizers::tokenizer::Trainer; +/// use tokenizers::models::backtracking_bpe::{BacktrackingBpe, BacktrackingBpeTrainer}; +/// +/// let sequences = vec![ "Hello", "World" ]; +/// +/// let mut trainer = BacktrackingBpeTrainer::default(); +/// trainer.feed(sequences.iter(), |s| Ok(vec![s.to_owned()])); +/// +/// let mut model = BacktrackingBpe::default(); +/// let special_tokens = trainer.train(&mut model).unwrap(); +/// ``` +#[non_exhaustive] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)] +pub struct BacktrackingBpeTrainer { + /// The minimum frequency a pair must have to produce a merge operation + pub min_frequency: u64, + /// The target vocabulary size + pub vocab_size: usize, + /// Whether to show progress while training + pub show_progress: bool, + /// A list of special tokens that the model should know of + pub 
special_tokens: Vec, + /// Whether to limit the number of initial tokens that can be kept before computing merges + pub limit_alphabet: Option, + /// The initial alphabet we want absolutely to include. This allows to cover + /// some characters that are not necessarily in the training set + pub initial_alphabet: HashSet, + /// An optional prefix to use on any subword that exist only behind another one + pub continuing_subword_prefix: Option, + /// An optional suffix to caracterize and end-of-word subword + pub end_of_word_suffix: Option, + /// An optional parameter to limit the max length of any single token + pub max_token_length: Option, + + words: HashMap, +} + +impl Default for BacktrackingBpeTrainer { + fn default() -> Self { + Self::builder().build() + } +} + +impl BacktrackingBpeTrainer { + pub fn new(min_frequency: u64, vocab_size: usize) -> Self { + Self { + min_frequency, + vocab_size, + ..Default::default() + } + } + + pub fn builder() -> BpeTrainerBuilder { + BpeTrainerBuilder::new() + } + + /// Setup a progress bar if asked to show progress + fn setup_progress(&self) -> Option { + if self.show_progress { + let p = ProgressBar::new(0); + p.set_style( + ProgressStyle::default_bar() + .template("[{elapsed_precise}] {msg:<30!} {wide_bar} {pos:<9!}/{len:>9!}") + .expect("Invalid progress template"), + ); + Some(p) + } else { + None + } + } + + /// Set the progress bar in the finish state + fn finalize_progress(&self, p: &Option, final_len: usize) { + if let Some(p) = p { + p.set_length(final_len as u64); + p.finish(); + println!(); + } + } + + /// Update the progress bar with the new provided length and message + fn update_progress(&self, p: &Option, len: usize, message: &'static str) { + if let Some(p) = p { + p.set_message(message); + p.set_length(len as u64); + p.reset(); + } + } + + /// Add the provided special tokens to the initial vocabulary + fn add_special_tokens(&self, w2id: &mut HashMap, id2w: &mut Vec) { + for token in &self.special_tokens { + if !w2id.contains_key(&token.content) { + id2w.push(token.content.to_owned()); + w2id.insert(token.content.to_owned(), (id2w.len() - 1) as u32); + } + } + } + + /// Compute the initial alphabet and limit it if relevant + fn compute_alphabet( + &self, + wc: &HashMap, + w2id: &mut HashMap, + id2w: &mut Vec, + ) { + // Compute the alphabet from seen words + let mut alphabet: HashMap = HashMap::new(); + for (word, count) in wc { + for c in word.chars() { + alphabet + .entry(c) + .and_modify(|cnt| *cnt += *count as usize) + .or_insert(*count as usize); + } + } + + // Also include anything from the provided initial alphabet + for c in &self.initial_alphabet { + alphabet + .entry(*c) + .and_modify(|cnt| *cnt = usize::MAX) + .or_insert(usize::MAX); + } + + let mut kept = alphabet.iter().collect::>(); + + // Compute the number of chars to remove from the alphabet + // If `limit_alphabet < initial_alphabet.len()`, some of these initial characters + // will be removed + let to_remove = self + .limit_alphabet + .map(|limit| { + if alphabet.len() > limit { + alphabet.len() - limit + } else { + 0 + } + }) + .unwrap_or(0); + + // Remove the unwanted chars + if to_remove > 0 { + kept.sort_unstable_by_key(|k| *k.1); + kept.drain(..to_remove); + } + + // Keep the initial alphabet (sorted for determinism) + kept.sort_unstable_by_key(|k| (*k.0) as u32); + kept.into_iter().for_each(|(c, _)| { + let s = c.to_string(); + if !w2id.contains_key(&s) { + id2w.push(s.clone()); + w2id.insert(s, (id2w.len() - 1) as u32); + } + }); + } + + /// Tokenize words and 
add subwords to the vocabulary when relevant + fn tokenize_words( + &self, + wc: &HashMap, + w2id: &mut HashMap, + id2w: &mut Vec, + p: &Option, + ) -> (Vec, Vec) { + let mut words: Vec = Vec::with_capacity(wc.len()); + let mut counts: Vec = Vec::with_capacity(wc.len()); + + for (word, count) in wc { + let mut current_word = Word::new(); + counts.push(*count); + + for (is_first, is_last, c) in word.chars().with_first_and_last() { + let mut s = c.to_string(); + if w2id.contains_key(&s) { + // Found the initial char in the authorized alphabet + + // Add the `continuing_subword_prefix` if relevant + if !is_first { + if let Some(prefix) = &self.continuing_subword_prefix { + s = format!("{prefix}{s}"); + } + } + // Add the `end_of_word_suffix` if relevant + if is_last { + if let Some(suffix) = &self.end_of_word_suffix { + s = format!("{s}{suffix}"); + } + } + + // Insert the new formed string if necessary + if !w2id.contains_key(&s) { + id2w.push(s.clone()); + w2id.insert(s.clone(), (id2w.len() - 1) as u32); + } + current_word.add(w2id[&s], 1); // We do not care about the len here + } + } + words.push(current_word); + + if let Some(p) = p { + p.inc(1); + } + } + + (words, counts) + } + + fn count_pairs( + &self, + words: &[Word], + counts: &[u64], + p: &Option, + ) -> (HashMap, HashMap>) { + words + .maybe_par_iter() + .enumerate() + .map(|(i, word)| { + let mut pair_counts = HashMap::new(); + let mut where_to_update: HashMap> = HashMap::new(); + + for window in word.get_chars().windows(2) { + let cur_pair: Pair = (window[0], window[1]); + + // Initialize pair_counts and where_to_update for this pair if we just saw it + if !pair_counts.contains_key(&cur_pair) { + pair_counts.insert(cur_pair, 0); + } + + // Then update counts + let count = counts[i]; + where_to_update + .entry(cur_pair) + .and_modify(|h| { + h.insert(i); + }) + .or_insert_with(|| { + let mut h = HashSet::new(); + h.insert(i); + h + }); + *pair_counts.get_mut(&cur_pair).unwrap() += count as i32; + } + + if let Some(p) = &p { + p.inc(1); + } + + (pair_counts, where_to_update) + }) + .reduce( + || (HashMap::new(), HashMap::new()), + |(mut pair_counts, mut where_to_update), (pc, wtu)| { + for (k, v) in pc { + pair_counts.entry(k).and_modify(|c| *c += v).or_insert(v); + } + for (k, v) in wtu { + where_to_update + .entry(k) + .and_modify(|set| *set = set.union(&v).copied().collect()) + .or_insert(v); + } + (pair_counts, where_to_update) + }, + ) + } + + pub fn do_train( + &self, + word_counts: &HashMap, + model: &mut BacktrackingBpe, + ) -> Result> { + let mut word_to_id: HashMap = HashMap::with_capacity(self.vocab_size); + let mut id_to_word: Vec = Vec::with_capacity(self.vocab_size); + let max_token_length: usize = self.max_token_length.unwrap_or(usize::MAX); + + let progress = self.setup_progress(); + + // + // 1. Add all special tokens to the vocabulary + // + self.add_special_tokens(&mut word_to_id, &mut id_to_word); + + // + // 2. Compute the initial alphabet + // + self.compute_alphabet(word_counts, &mut word_to_id, &mut id_to_word); + + // + // 3. Tokenize words + // + self.update_progress(&progress, word_counts.len(), "Tokenize words"); + let (mut words, counts) = + self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress); + self.finalize_progress(&progress, words.len()); + + // + // 4. 
Count pairs in words + // + self.update_progress(&progress, words.len(), "Count pairs"); + let (mut pair_counts, mut where_to_update) = self.count_pairs(&words, &counts, &progress); + // Insert them in the queue + let mut queue = BinaryHeap::with_capacity(pair_counts.len()); + where_to_update.drain().for_each(|(pair, pos)| { + let count = pair_counts[&pair]; + if count > 0 { + queue.push(Merge { + pair, + count: count as u64, + pos, + }); + } + }); + self.finalize_progress(&progress, words.len()); + + // + // 5. Do merges + // + self.update_progress(&progress, self.vocab_size, "Compute merges"); + let mut merges: Vec<(Pair, u32)> = vec![]; + loop { + // Stop as soon as we have a big enough vocabulary + if word_to_id.len() >= self.vocab_size { + break; + } + + if queue.is_empty() { + break; + } + + let mut top = queue.pop().unwrap(); + if top.count != pair_counts[&top.pair] as u64 { + top.count = pair_counts[&top.pair] as u64; + queue.push(top); + continue; + } + + if top.count < 1 || self.min_frequency > top.count { + break; + } + + let part_a = &id_to_word[top.pair.0 as usize]; + let mut part_b = id_to_word[top.pair.1 as usize].to_owned(); + + // Build new token + if let Some(prefix) = &self.continuing_subword_prefix { + if part_b.starts_with(prefix) { + let prefix_byte_len = prefix.chars().map(|c| c.len_utf8()).sum(); + part_b = part_b[prefix_byte_len..].to_string(); + } + } + let new_token = format!("{part_a}{part_b}"); + // implement sentencepiece-like merge. + // if this code were to be merged, integrate a way in the python bindings to communicate this variable + // default should be 0/None to maintain previous behavior. 16 is the spm default. + + // Insert new token if it does not already exist + let new_token_id = word_to_id + .get(&new_token) + .copied() + .unwrap_or(id_to_word.len() as u32); + if !word_to_id.contains_key(&new_token) { + id_to_word.push(new_token.clone()); + word_to_id.insert(new_token.clone(), new_token_id); + } + merges.push((top.pair, new_token_id)); + + // Merge the new pair in every words + // Safety: This is just a type assertion, the code below may no longer be safe + // if the type of `pos` changes + let pos: &HashSet = &top.pos; + + let words_len = words.len(); + struct WordPtr(*mut Word); + // Safety: We do not actually use this for concurrent access to the same memory, + // only to different chunks within the same allocation. + unsafe impl Sync for WordPtr {} + let word_start = WordPtr(words.as_mut_ptr()); + + let changes = pos + .maybe_par_iter() + .flat_map(|&i| { + // Safety: + // We are producing a valid pointer since we are indexing in bounds + // + // We can access each `word` here in parallel because each position + // can be there only once (pos is a HashSet). 
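+                    // This mirrors the merge-application step of the existing BPE
+                    // trainer: `Word::merge` rewrites the pair (A, B) -> new_token_id
+                    // inside word `i` and returns pair-count deltas, which are folded
+                    // back into `pair_counts` / `where_to_update` right below.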
+ unsafe { + assert!(i < words_len); + // This is words[i], but avoids needing to go through &T (which triggers UB) + let word = word_start.0.add(i); + // let word: &mut Word = &mut (*word); + (*word) + .merge(top.pair.0, top.pair.1, new_token_id, max_token_length) + .into_iter() + .map(|c| (c, i)) + .collect::<Vec<_>>() + } + }) + .collect::<Vec<_>>(); + + // Introduce newly formed pairs + for ((pair, change), iw) in changes { + let count = change * counts[iw] as i32; + pair_counts + .entry(pair) + .and_modify(|c| *c += count) + .or_insert(count); + if change > 0 { + where_to_update + .entry(pair) + .and_modify(|h| { + h.insert(iw); + }) + .or_insert_with(|| { + let mut h = HashSet::new(); + h.insert(iw); + h + }); + } + } + where_to_update.drain().for_each(|(pair, pos)| { + let count = pair_counts[&pair]; + if count > 0 { + queue.push(Merge { + pair, + count: count as u64, + pos, + }); + } + }); + + if let Some(p) = &progress { + p.inc(1); + } + } + self.finalize_progress(&progress, merges.len()); + + // Transfer new vocab & options to model + model.vocab = word_to_id; + model.vocab_r = model + .vocab + .iter() + .map(|(key, val)| (*val, key.to_owned())) + .collect(); + model.merges = merges + .into_iter() + .enumerate() + .map(|(i, (pair, new_token_id))| (pair, (i as u32, new_token_id))) + .collect(); + + Ok(self.special_tokens.clone()) + } +} + +impl Trainer for BacktrackingBpeTrainer { + type Model = BacktrackingBpe; + + /// Train a BacktrackingBpe model + fn train(&self, model: &mut BacktrackingBpe) -> Result<Vec<AddedToken>> { + self.do_train(&self.words, model) + } + + /// Whether we should show progress + fn should_show_progress(&self) -> bool { + self.show_progress + } + + fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()> + where + I: Iterator<Item = S> + Send, + S: AsRef<str> + Send, + F: Fn(&str) -> Result<Vec<String>> + Sync, + { + let words: Result<HashMap<String, u64>> = iterator + .maybe_par_bridge() + .map(|sequence| { + let words = process(sequence.as_ref())?; + let mut map = HashMap::new(); + for word in words { + map.entry(word).and_modify(|c| *c += 1).or_insert(1); + } + Ok(map) + }) + .reduce( + || Ok(HashMap::new()), + |acc, ws| { + let mut acc = acc?; + for (k, v) in ws? { + acc.entry(k).and_modify(|c| *c += v).or_insert(v); + } + Ok(acc) + }, + ); + + self.words = words?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::{BacktrackingBpe, BacktrackingBpeTrainer, Pair}; + use std::collections::HashMap; + + #[test] + fn test_train() { + let word_counts: HashMap<String, u64> = [ + ("roses".into(), 1), + ("are".into(), 2), + ("red".into(), 1), + ("voilets".into(), 1), + ("blue".into(), 1), + ("BERT".into(), 1), + ("is".into(), 2), + ("big".into(), 1), + ("and".into(), 1), + ("so".into(), 1), + ("GPT-2".into(), 1), + ] + .iter() + .cloned() + .collect(); + let trainer = BacktrackingBpeTrainer::builder() + .show_progress(false) + .min_frequency(2) + .build(); + let mut model = BacktrackingBpe::default(); + trainer.do_train(&word_counts, &mut model).unwrap(); + + // Vocab should contain all of the characters from the `word_counts` mapping + // as well as three merges: 're', 'are', and 'is'.
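+ // Only pairs seen at least twice can be merged here: ('r', 'e') occurs in "are" (twice) and "red", + // then ('a', 're') occurs twice after that merge, and ('i', 's') occurs twice in "is"; every other + // surviving pair appears only once, so exactly three merges are learned.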
+ let expected_vocab: HashMap = [ + ("-".into(), 0), + ("2".into(), 1), + ("B".into(), 2), + ("E".into(), 3), + ("G".into(), 4), + ("P".into(), 5), + ("R".into(), 6), + ("T".into(), 7), + ("a".into(), 8), + ("b".into(), 9), + ("d".into(), 10), + ("e".into(), 11), + ("g".into(), 12), + ("i".into(), 13), + ("l".into(), 14), + ("n".into(), 15), + ("o".into(), 16), + ("r".into(), 17), + ("s".into(), 18), + ("t".into(), 19), + ("u".into(), 20), + ("v".into(), 21), + ("re".into(), 22), + ("are".into(), 23), + ("is".into(), 24), + ] + .iter() + .cloned() + .collect(); + assert_eq!(model.vocab, expected_vocab); + + // The keys in `merges` are pairs of symbols, the values are tuples of (rank, id), + // where 'rank' determines the order in which this merge will be applied during + // tokenization, and 'id' is the vocab id of the symbol resulting from merging + // the pair of symbols in the corresponding key. + let expected_merges: HashMap = [ + ((17, 11), (0, 22)), // 'r' + 'e' -> 're' + ((8, 22), (1, 23)), // 'a' + 're' -> 'are' + ((13, 18), (2, 24)), // 'i' + 's' -> 'is' + ] + .iter() + .cloned() + .collect(); + assert_eq!(model.merges, expected_merges); + } + #[test] + fn bpe_test_max_token_length_16() { + /* bpe_test_max_token_length series of tests test the max_token_length flag of bpetrainer + // this is the more robust version that only tests max length of learned tokens + // (pre) tokenizer settings or vocab can be easily modified when necessary + */ + + let max_token_length = 16; + let long_word_counts: HashMap = [ + ("singlelongtokenwithoutcasechange", 2), + ("singleLongTokenWithCamelCaseChange", 2), + ("Longsingletokenwithpunctu@t!onwithin", 2), + ("Anotherlongsingletokenwithnumberw1th1n", 2), + ("짧은한글문자열짧은한", 2), // korean 10 char + ("긴한글문자열긴한글문자열긴한글문", 2), // korean 16 char + ("短字符串短字符串短字", 2), //simplified chinese 10 char + ("长字符串长字符串长字符串长字符串", 2), // simp. chinese 16 char + ("短い文字列短い文字列", 2), // japanese 10 char + ("長い文字列長い文字列長い文字列長", 2), // japanese 16 char + ("so", 2), + ("GPT-2", 2), + ] + .iter() + .map(|(key, value)| (key.to_string(), *value)) + .collect(); + let trainer = BacktrackingBpeTrainer::builder() + .max_token_length(Some(max_token_length)) + .show_progress(false) + .min_frequency(0) + .build(); + let mut model = BacktrackingBpe::default(); + trainer.do_train(&long_word_counts, &mut model).unwrap(); + let vocab = model.get_vocab(); + for token in vocab.keys() { + assert!( + token.chars().count() <= max_token_length, + "token too long : {} , chars().count() = {}", + token, + token.chars().count() + ) + } + } + #[test] + fn bpe_test_max_token_length_direct_assert() { + /* more direct version of bpe_test_max_token_length test + // directly compares tokens with known expected values. + // maybe unstable depending on specific settings or changes. 
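+ // in particular, the single-character ids below follow the characters' codepoint order, + // and the merged tokens take the next ids in the order the merges are learned.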
+ */ + let long_word_counts: HashMap = [ + ("sin", 2), + ("Sin", 2), + ("Lon", 2), + ("Ano", 2), + ("짧은한", 2), + ("긴한글", 2), + ("短字符", 2), + ("长字符", 2), + ("短い文", 2), + ("長い文", 2), + ("so", 2), + ("GP", 2), + ] + .iter() + .map(|(key, value)| (key.to_string(), *value)) + .collect(); + let trainer = BacktrackingBpeTrainer::builder() + .max_token_length(Some(2)) + .show_progress(false) + .min_frequency(0) + .build(); + let mut model = BacktrackingBpe::default(); + trainer.do_train(&long_word_counts, &mut model).unwrap(); + let trained_vocab: HashMap = model.get_vocab(); + let expected_vocab: HashMap = [ + ("短", 12), + ("n", 6), + ("i", 5), + ("s", 8), + ("字符", 23), + ("長", 14), + ("긴", 17), + ("い文", 22), + ("L", 2), + ("in", 21), + ("o", 7), + ("은한", 29), + ("S", 4), + ("P", 3), + ("so", 27), + ("符", 13), + ("文", 11), + ("字", 10), + ("짧", 19), + ("GP", 25), + ("글", 16), + ("G", 1), + ("An", 24), + ("长", 15), + ("A", 0), + ("Lo", 26), + ("긴한", 28), + ("い", 9), + ("한", 20), + ("은", 18), + ] + .iter() + .cloned() + .map(|(k, v)| (k.to_string(), v)) + .collect(); + assert_eq!(trained_vocab, expected_vocab) + } +} diff --git a/tokenizers/src/models/bpe/mod.rs b/tokenizers/src/models/bpe/mod.rs index f0d40b2df..97d8e12ad 100644 --- a/tokenizers/src/models/bpe/mod.rs +++ b/tokenizers/src/models/bpe/mod.rs @@ -4,9 +4,9 @@ use std::{iter, mem}; mod model; mod serialization; pub mod trainer; -mod word; +pub mod word; -type Pair = (u32, u32); +pub(crate) type Pair = (u32, u32); /// Errors that can be encountered while using or constructing a `BPE` model. #[derive(thiserror::Error, Debug)] @@ -79,4 +79,4 @@ where // Re-export pub use model::*; pub use trainer::*; -use word::*; +pub use word::*; diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs index 6fc8033e3..32720a126 100644 --- a/tokenizers/src/models/bpe/word.rs +++ b/tokenizers/src/models/bpe/word.rs @@ -53,7 +53,7 @@ impl Symbol { } #[derive(Clone, Default)] -pub(super) struct Word { +pub struct Word { symbols: Vec, } impl std::fmt::Debug for Word { @@ -74,7 +74,7 @@ impl std::fmt::Debug for Word { } impl Word { - pub(super) fn new() -> Self { + pub(crate) fn new() -> Self { Word { symbols: vec![] } } @@ -84,7 +84,7 @@ impl Word { } } - pub(super) fn add(&mut self, c: u32, byte_len: usize) { + pub(crate) fn add(&mut self, c: u32, byte_len: usize) { let (prev, next) = { let len = self.symbols.len() as isize; if let Some(last) = self.symbols.last_mut() { @@ -103,7 +103,7 @@ impl Word { }); } - pub(super) fn merge( + pub(crate) fn merge( &mut self, c1: u32, c2: u32, @@ -251,7 +251,7 @@ impl Word { self.symbols.retain(|s| s.len != 0); } - pub(super) fn get_chars(&self) -> Vec { + pub(crate) fn get_chars(&self) -> Vec { self.symbols.iter().map(|s| s.c).collect() } diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index 3ab3b495b..818c8f35f 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -1,5 +1,6 @@ //! Popular tokenizer models. 
+pub mod backtracking_bpe; pub mod bpe; pub mod unigram; pub mod wordlevel; @@ -10,6 +11,7 @@ use std::path::{Path, PathBuf}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use crate::models::backtracking_bpe::{BacktrackingBpe, BacktrackingBpeTrainer}; use crate::models::bpe::{BpeTrainer, BPE}; use crate::models::unigram::{Unigram, UnigramTrainer}; use crate::models::wordlevel::{WordLevel, WordLevelTrainer}; @@ -61,6 +63,7 @@ impl<'a> Serialize for OrderedVocabIter<'a> { #[serde(untagged)] pub enum ModelWrapper { BPE(BPE), + BacktrackingBpe(BacktrackingBpe), // WordPiece must stay before WordLevel here for deserialization (for retrocompatibility // with the versions not including the "type"), since WordLevel is a subset of WordPiece WordPiece(WordPiece), @@ -86,6 +89,7 @@ impl<'de> Deserialize<'de> for ModelWrapper { WordPiece, WordLevel, Unigram, + BacktrackingBpe, } #[derive(Deserialize)] @@ -121,6 +125,9 @@ impl<'de> Deserialize<'de> for ModelWrapper { EnumType::Unigram => ModelWrapper::Unigram( serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?, ), + EnumType::BacktrackingBpe => ModelWrapper::BacktrackingBpe( + serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?, + ), }, ModelHelper::Legacy(value) => { let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?; @@ -139,6 +146,7 @@ impl_enum_from!(WordLevel, ModelWrapper, WordLevel); impl_enum_from!(WordPiece, ModelWrapper, WordPiece); impl_enum_from!(BPE, ModelWrapper, BPE); impl_enum_from!(Unigram, ModelWrapper, Unigram); +impl_enum_from!(BacktrackingBpe, ModelWrapper, BacktrackingBpe); impl Model for ModelWrapper { type Trainer = TrainerWrapper; @@ -149,6 +157,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.tokenize(tokens), Self::BPE(t) => t.tokenize(tokens), Self::Unigram(t) => t.tokenize(tokens), + Self::BacktrackingBpe(t) => t.tokenize(tokens), } } @@ -158,6 +167,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.token_to_id(token), Self::BPE(t) => t.token_to_id(token), Self::Unigram(t) => t.token_to_id(token), + Self::BacktrackingBpe(t) => t.token_to_id(token), } } @@ -167,6 +177,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.id_to_token(id), Self::BPE(t) => t.id_to_token(id), Self::Unigram(t) => t.id_to_token(id), + Self::BacktrackingBpe(t) => t.id_to_token(id), } } @@ -176,6 +187,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.get_vocab(), Self::BPE(t) => t.get_vocab(), Self::Unigram(t) => t.get_vocab(), + Self::BacktrackingBpe(t) => t.get_vocab(), } } @@ -185,6 +197,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.get_vocab_size(), Self::BPE(t) => t.get_vocab_size(), Self::Unigram(t) => t.get_vocab_size(), + Self::BacktrackingBpe(t) => t.get_vocab_size(), } } @@ -194,6 +207,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.save(folder, name), Self::BPE(t) => t.save(folder, name), Self::Unigram(t) => t.save(folder, name), + Self::BacktrackingBpe(t) => t.save(folder, name), } } @@ -203,6 +217,7 @@ impl Model for ModelWrapper { Self::WordPiece(t) => t.get_trainer().into(), Self::BPE(t) => t.get_trainer().into(), Self::Unigram(t) => t.get_trainer().into(), + Self::BacktrackingBpe(t) => t.get_trainer().into(), } } } @@ -230,6 +245,7 @@ pub enum TrainerWrapper { WordPieceTrainer(WordPieceTrainer), WordLevelTrainer(WordLevelTrainer), UnigramTrainer(UnigramTrainer), + BacktrackingBpeTrainer(BacktrackingBpeTrainer), } impl Trainer for TrainerWrapper { @@ -241,6 +257,7 @@ impl Trainer for 
TrainerWrapper { Self::WordPieceTrainer(wpt) => wpt.should_show_progress(), Self::WordLevelTrainer(wpt) => wpt.should_show_progress(), Self::UnigramTrainer(wpt) => wpt.should_show_progress(), + Self::BacktrackingBpeTrainer(wpt) => wpt.should_show_progress(), } } @@ -262,6 +279,10 @@ impl Trainer for TrainerWrapper { ModelWrapper::Unigram(u) => t.train(u), _ => Err("UnigramTrainer can only train a Unigram".into()), }, + Self::BacktrackingBpeTrainer(t) => match model { + ModelWrapper::BacktrackingBpe(bpe) => t.train(bpe), + _ => Err("BacktrackingBpeTrainer can only train a BacktrackingBpe".into()), + }, } } @@ -276,6 +297,7 @@ impl Trainer for TrainerWrapper { Self::WordPieceTrainer(wpt) => wpt.feed(iterator, process), Self::WordLevelTrainer(wpt) => wpt.feed(iterator, process), Self::UnigramTrainer(wpt) => wpt.feed(iterator, process), + Self::BacktrackingBpeTrainer(wpt) => wpt.feed(iterator, process), } } } @@ -284,6 +306,11 @@ impl_enum_from!(BpeTrainer, TrainerWrapper, BpeTrainer); impl_enum_from!(WordPieceTrainer, TrainerWrapper, WordPieceTrainer); impl_enum_from!(UnigramTrainer, TrainerWrapper, UnigramTrainer); impl_enum_from!(WordLevelTrainer, TrainerWrapper, WordLevelTrainer); +impl_enum_from!( + BacktrackingBpeTrainer, + TrainerWrapper, + BacktrackingBpeTrainer +); #[cfg(test)] mod tests {
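A minimal end-to-end sketch of driving the new trainer through the generic `Trainer` interface, assuming the `models::backtracking_bpe` re-export added above and the usual crate-root exports of `Model`, `Trainer` and `Result`; the whitespace-splitting `process` callback is only illustrative and stands in for a real pre-tokenizer:

use tokenizers::models::backtracking_bpe::{BacktrackingBpe, BacktrackingBpeTrainer};
use tokenizers::{Model, Result, Trainer};

fn train_tiny_backtracking_bpe() -> Result<BacktrackingBpe> {
    // Configure the trainer the same way as the unit tests above.
    let mut trainer = BacktrackingBpeTrainer::builder()
        .show_progress(false)
        .min_frequency(2)
        .build();
    // `feed` runs the callback over every sequence and accumulates word counts.
    trainer.feed(
        ["roses are red", "violets are blue"].into_iter(),
        |line| Ok(line.split_whitespace().map(|w| w.to_string()).collect()),
    )?;
    // `train` runs `do_train` on the accumulated counts and fills the model's
    // vocab, reverse vocab and merges.
    let mut model = BacktrackingBpe::default();
    let _special_tokens = trainer.train(&mut model)?;
    assert!(model.get_vocab_size() > 0);
    Ok(model)
}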