Draft backtrack #1712

Draft: wants to merge 11 commits into base: main
140 changes: 140 additions & 0 deletions bindings/python/src/models.rs
@@ -36,6 +36,7 @@ impl PyModel {
let base = self.clone();
Ok(match *self.model.as_ref().read().unwrap() {
ModelWrapper::BPE(_) => Py::new(py, (PyBPE {}, base))?.into_py(py),
ModelWrapper::BacktrackingBpe(_) => Py::new(py, (PyBacktrackingBpe {}, base))?.into_py(py),
ModelWrapper::WordPiece(_) => Py::new(py, (PyWordPiece {}, base))?.into_py(py),
ModelWrapper::WordLevel(_) => Py::new(py, (PyWordLevel {}, base))?.into_py(py),
ModelWrapper::Unigram(_) => Py::new(py, (PyUnigram {}, base))?.into_py(py),
@@ -560,6 +561,145 @@ impl PyBPE {
}
}

#[pyclass(extends=PyModel, module = "tokenizers.models", name = "BacktrackingBpe")]
struct PyBacktrackingBpe {}

#[pymethods]
impl PyBacktrackingBpe {
#[getter]
fn get_dropout(self_: PyRef<Self>) -> Option<f32> {
getter!(self_, BPE, dropout)
}

#[setter]
fn set_dropout(self_: PyRef<Self>, dropout: Option<f32>) {
setter!(self_, BPE, dropout, dropout);
}

#[getter]
fn get_unk_token(self_: PyRef<Self>) -> Option<String> {
getter!(self_, BPE, unk_token.clone())
}

#[setter]
fn set_unk_token(self_: PyRef<Self>, unk_token: Option<String>) {
setter!(self_, BPE, unk_token, unk_token);
}
#[new]
#[pyo3(
signature = (vocab=None, merges=None, **kwargs),
text_signature = "(self, vocab=None, merges=None, dropout=None, unk_token=None)")]
fn new(
py: Python<'_>,
vocab: Option<PyVocab>,
merges: Option<PyMerges>,
kwargs: Option<&Bound<'_, PyDict>>,
) -> PyResult<(Self, PyModel)> {
if (vocab.is_some() && merges.is_none()) || (vocab.is_none() && merges.is_some()) {
return Err(exceptions::PyValueError::new_err(
"`vocab` and `merges` must be both specified",
));
}

let mut builder = BPE::builder();
if let (Some(vocab), Some(merges)) = (vocab, merges) {
match (vocab, merges) {
(PyVocab::Vocab(vocab), PyMerges::Merges(merges)) => {
builder = builder.vocab_and_merges(vocab, merges);
}
(PyVocab::Filename(vocab_filename), PyMerges::Filename(merges_filename)) => {
deprecation_warning(
py,
"0.9.0",
"BPE.__init__ will not create from files anymore, try `BPE.from_file` instead",
)?;
builder =
builder.files(vocab_filename.to_string(), merges_filename.to_string());
}
_ => {
return Err(exceptions::PyValueError::new_err(
"`vocab` and `merges` must be both be from memory or both filenames",
));
}
}
}

PyBPE::with_builder(builder, kwargs)
}

/// Read a :obj:`vocab.json` and a :obj:`merges.txt` file
///
/// This method provides a way to read and parse the content of these files,
/// returning the relevant data structures. If you want to instantiate some BPE models
/// from memory, this method gives you the expected input from the standard files.
///
/// Args:
/// vocab (:obj:`str`):
/// The path to a :obj:`vocab.json` file
///
/// merges (:obj:`str`):
/// The path to a :obj:`merges.txt` file
///
/// Returns:
/// A :obj:`Tuple` with the vocab and the merges:
/// The vocabulary and merges loaded into memory
#[staticmethod]
#[pyo3(text_signature = "(self, vocab, merges)")]
fn read_file(vocab: &str, merges: &str) -> PyResult<(Vocab, Merges)> {
BPE::read_file(vocab, merges).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while reading vocab & merges files: {}",
e
))
})
}

/// Instantiate a BacktrackingBpe model from the given files.
///
/// This method is roughly equivalent to doing::
///
/// vocab, merges = BacktrackingBpe.read_file(vocab_filename, merges_filename)
/// bpe = BacktrackingBpe(vocab, merges)
///
/// If you don't need to keep the :obj:`vocab, merges` values lying around,
/// this method is more optimized than manually calling
/// :meth:`~tokenizers.models.BacktrackingBpe.read_file` to initialize a :class:`~tokenizers.models.BacktrackingBpe`
///
/// Args:
/// vocab (:obj:`str`):
/// The path to a :obj:`vocab.json` file
///
/// merges (:obj:`str`):
/// The path to a :obj:`merges.txt` file
///
/// Returns:
/// :class:`~tokenizers.models.BacktrackingBpe`: An instance of BacktrackingBpe loaded from these files
#[classmethod]
#[pyo3(signature = (vocab, merges, **kwargs))]
#[pyo3(text_signature = "(cls, vocab, merge, **kwargs)")]
fn from_file(
_cls: &Bound<'_, PyType>,
py: Python,
vocab: &str,
merges: &str,
kwargs: Option<&Bound<'_, PyDict>>,
) -> PyResult<Py<Self>> {
let (vocab, merges) = BPE::read_file(vocab, merges).map_err(|e| {
exceptions::PyException::new_err(format!("Error while reading BPE files: {}", e))
})?;
Py::new(
py,
PyBPE::new(
py,
Some(PyVocab::Vocab(vocab)),
Some(PyMerges::Merges(merges)),
kwargs,
)?,
)
}
}
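For context, the `getter!` and `setter!` helpers used above dispatch on the `ModelWrapper` variant held inside the shared `PyModel`. A simplified sketch of the read side (illustrative only; the real macro defined earlier in `models.rs` may differ in details):

```rust
// Simplified sketch of the getter! helper (illustration, not the PR's code):
// lock the shared model, match the expected ModelWrapper variant, and
// evaluate the field expression on it.
macro_rules! getter {
    ($self:ident, $variant:ident, $($name:tt)+) => {{
        let super_ = $self.as_ref();
        let model = super_.model.as_ref().read().unwrap();
        if let ModelWrapper::$variant(ref m) = *model {
            m.$($name)+
        } else {
            unreachable!()
        }
    }};
}
```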


/// An implementation of the WordPiece algorithm
///
/// Args:
2 changes: 2 additions & 0 deletions tokenizers/Cargo.toml
@@ -68,6 +68,8 @@ fancy-regex = { version = "0.13", optional = true}
getrandom = { version = "0.2.10" }
esaxx-rs = { version = "0.1.10", default-features = false, features=[]}
monostate = "0.1.12"
fnv = "1.0.7"
aneubeck-daachorse = "1.1.1"
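(`fnv` provides the Fowler-Noll-Vo hash, commonly used for fast hashing of small keys; `aneubeck-daachorse` is a fork of the `daachorse` double-array Aho-Corasick crate, presumably used here for the leftmost-longest token matching the backtracking model relies on.)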

[features]
default = ["progressbar", "onig", "esaxx_fast"]
74 changes: 65 additions & 9 deletions tokenizers/benches/llama3.rs
@@ -2,26 +2,81 @@
extern crate criterion;

use criterion::{Criterion, Throughput};
use itertools::Itertools;
use tokenizers::models::backtracking_bpe;
use tokenizers::PreTokenizerWrapper;
use tokenizers::Tokenizer;

pub fn llama3(c: &mut Criterion) {
let data = std::fs::read_to_string("data/big.txt").unwrap();
let mut group = c.benchmark_group("llama3-encode");
group.throughput(Throughput::Bytes(data.bytes().len() as u64));
group.bench_function("llama3-offsets", |b| {
let tokenizer =
Tokenizer::from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct", None).unwrap();

group.bench_function("llama3-backtracking", |b| {
let mut tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap();
let mut vocab = &mut tokenizer
.get_vocab(false)
.clone()
.into_iter()
.collect::<Vec<_>>(); // Convert HashMap into a Vec of (String, u32) tuples
//
vocab.sort_by(|a, b| a.1.cmp(&b.1));
vocab.truncate(vocab.len().saturating_sub(3));
let vocab: Vec<_> = vocab // Sort by u32 value
.into_iter() // IntoIterator to get the iterator of Vec<u8>
.map(|(tok, _)| Vec::from(tok.as_bytes()))
.collect();
let model: backtracking_bpe::BacktrackingBpe =
backtracking_bpe::BacktrackingBpe::from_dictionary(vocab, None, None);
tokenizer.with_model(model);
let data: Vec<_> = data.lines().collect();
let add_special_tokens = false;
b.iter(|| {
tokenizer
.encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens)
.unwrap()
})
});

group.bench_function("llama3-backtracking-no-pretok", |b| {
let mut tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap();
let mut vocab = &mut tokenizer
.get_vocab(false)
.clone()
.into_iter()
.collect::<Vec<_>>(); // Convert HashMap into a Vec of (String, u32) tuples
//
vocab.sort_by(|a, b| a.1.cmp(&b.1));
vocab.truncate(vocab.len().saturating_sub(3));
let vocab: Vec<_> = vocab // Sort by u32 value
.into_iter() // IntoIterator to get the iterator of Vec<u8>
.map(|(tok, _)| Vec::from(tok.as_bytes()))
.collect();
let model: backtracking_bpe::BacktrackingBpe =
backtracking_bpe::BacktrackingBpe::from_dictionary(vocab, None, None);
tokenizer.with_model(model);
tokenizer.with_pre_tokenizer(None::<PreTokenizerWrapper>);
let data: Vec<_> = data.lines().collect();
let add_special_tokens = false;
b.iter(|| {
tokenizer
-.encode_batch_char_offsets(criterion::black_box(data.clone()), add_special_tokens)
.encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens)
.unwrap()
})
});
group.bench_function("llama3-nooffsets", |b| {
let tokenizer =
Tokenizer::from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct", None).unwrap();

group.bench_function("llama3-encode_batch_fast", |b| {
let tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap();
let data: Vec<_> = data.lines().collect();
let add_special_tokens = false;
b.iter(|| {
tokenizer
.encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens)
.unwrap()
})
});
group.bench_function("llama3-encode_batch", |b| {
let tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap();
let data: Vec<_> = data.lines().collect();
let add_special_tokens = false;
b.iter(|| {
@@ -30,13 +85,14 @@ pub fn llama3(c: &mut Criterion) {
.unwrap()
})
});

group.finish();
}

criterion_group! {
-name = bert_benches;
name = llama;
config = Criterion::default().sample_size(10);
targets = llama3
}

-criterion_main!(bert_benches);
criterion_main!(llama);
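The construction pattern shared by the two backtracking benches above can be condensed into one helper. A sketch under the API as shown in this diff (the helper name `with_backtracking` is mine, and the two trailing `None` arguments to `from_dictionary` are passed through exactly as the benchmark does, since their meaning is not shown here):

```rust
use tokenizers::models::backtracking_bpe::BacktrackingBpe;
use tokenizers::Tokenizer;

// Sketch: derive a byte-level dictionary from an existing tokenizer and
// swap in the backtracking BPE model, as the benchmarks above do.
fn with_backtracking(pretrained: &str) -> tokenizers::Result<Tokenizer> {
    let mut tokenizer = Tokenizer::from_pretrained(pretrained, None)?;
    let mut vocab: Vec<_> = tokenizer.get_vocab(false).into_iter().collect();
    vocab.sort_by_key(|p| p.1); // order tokens by id
    vocab.truncate(vocab.len().saturating_sub(3)); // drop trailing special tokens
    let dictionary: Vec<Vec<u8>> = vocab
        .into_iter()
        .map(|(tok, _)| tok.into_bytes())
        .collect();
    tokenizer.with_model(BacktrackingBpe::from_dictionary(dictionary, None, None));
    Ok(tokenizer)
}
```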
49 changes: 49 additions & 0 deletions tokenizers/src/models/backtracking_bpe/backtracking_state.rs
@@ -0,0 +1,49 @@
use super::bitfield::BitField;

/// This can be thought of as a lazy variation of the dynamic programming approach.
/// It only computes those states which have to be visited in order to compute the tokenization
/// for a given input text.
/// It keeps track of visited states in a bitfield and only remembers the tokenization
/// of the currently processed dynamic programming state.
///
/// The biggest downside of this approach is that the search for the longest leftmost match (i.e. the next
/// candidate token) has to be restarted at every backtracking step; even so, it is a net win in practice compared to other approaches.
#[derive(Clone, PartialEq)]
pub struct BacktrackState<'a> {
pub(crate) text: &'a [u8],
pub(crate) tokens: Vec<u32>, // roughly text.len() / 3 tokens are expected
pub(crate) next_token: Option<u32>, // bpe.next_match(text), i.e. the value of longest_searcher.leftmost_find_iter(text)'s first match
pub(crate) pos: usize, // current position in the text
pub(crate) bitfield: BitField, // keeps track of all valid tokenization positions; this is what makes the runtime linear in the input length
}

impl<'a> BacktrackState<'a> {
pub(crate) fn new(text: &'a [u8], next_token: Option<u32>) -> Self {
Self::with_capacity(text, next_token, text.len() / 3)
}

pub(crate) fn with_capacity(text: &'a [u8], next_token: Option<u32>, cap: usize) -> Self {
Self {
text,
tokens: Vec::with_capacity(cap),
next_token,
pos: 0,
bitfield: BitField::new(text.len() + 1),
}
}
pub(crate) fn count(&self) -> usize {
self.tokens.len()
}

pub(crate) fn pos(&self) -> usize {
self.pos
}

pub(crate) fn last_token(&self) -> Option<u32> {
self.tokens.last().copied()
}

pub(crate) fn into_tokens(self) -> Vec<u32> {
self.tokens
}
}
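This file only defines the search state; the loop that advances it lives in `model.rs` and is not part of this file's diff. As a rough sketch of the backtracking idea the doc comment describes (the helper methods `token_len`, `is_valid_token_pair`, `next_match`, and `next_prefix` are assumed names for illustration, not the PR's exact code):

```rust
// Hypothetical single step of the backtracking loop (sketch only).
fn step(bpe: &BacktrackingBpe, st: &mut BacktrackState<'_>) -> Option<u32> {
    let mut token = st.next_token?;
    let last = st.last_token();
    loop {
        let end = st.pos + bpe.token_len(token);
        // Accept `token` only if this boundary is still viable and the
        // (last, token) pair is compatible with the BPE merge order.
        if st.bitfield.is_set(end)
            && last.map_or(true, |last| bpe.is_valid_token_pair(last, token))
        {
            st.tokens.push(token);
            st.pos = end;
            st.next_token = bpe.next_match(&st.text[end..]);
            break;
        } else if let Some(shorter) = bpe.next_prefix(token) {
            // Retry with the next shorter token that is a prefix of `token`.
            token = shorter;
        } else {
            // Dead end: mark this position as invalid so it is never tried
            // again (this is what keeps the runtime linear), then backtrack
            // by popping the previous token.
            st.bitfield.clear(st.pos);
            st.tokens.pop();
            st.pos -= last.map_or(0, |t| bpe.token_len(t));
            st.next_token = last;
            break;
        }
    }
    st.next_token
}
```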
57 changes: 57 additions & 0 deletions tokenizers/src/models/backtracking_bpe/bitfield.rs
@@ -0,0 +1,57 @@
/// Small helper to manage a bit field which supports predecessor and successor queries with a simple scan implementation.
/// This is sufficient for our use case, since any two set bits are at most 128 bits apart.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct BitField {
bitfield: Vec<u64>,
}

impl BitField {
/// All bits are initialized to 1.
pub(crate) fn new(bits: usize) -> Self {
Self {
bitfield: vec![u64::MAX; (bits + 63) / 64],
}
}

pub(crate) fn is_set(&self, bit: usize) -> bool {
let (word, bit) = (bit / 64, bit % 64);
self.bitfield[word] & (1 << bit) != 0
}

pub(crate) fn clear(&mut self, bit: usize) {
let (word, bit) = (bit / 64, bit % 64);
self.bitfield[word] &= !(1 << bit);
}

pub(crate) fn successor(&self, bit: usize) -> usize {
let (mut word_idx, bit_idx) = (bit / 64, bit % 64);
let word = self.bitfield[word_idx] >> bit_idx;
if word != 0 {
word.trailing_zeros() as usize + bit
} else {
loop {
word_idx += 1;
let word = self.bitfield[word_idx];
if word != 0 {
break word.trailing_zeros() as usize + word_idx * 64;
}
}
}
}

pub(crate) fn predecessor(&self, bit: usize) -> usize {
let (mut word_idx, bit_idx) = (bit / 64, bit % 64);
let word = self.bitfield[word_idx] << (63 - bit_idx);
if word != 0 {
bit - word.leading_zeros() as usize
} else {
loop {
word_idx -= 1;
let word = self.bitfield[word_idx];
if word != 0 {
break word_idx * 64 + 63 - word.leading_zeros() as usize;
}
}
}
}
}
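A minimal sketch of the expected semantics, written as a crate-internal unit test (an illustration, not part of the diff). Note that both scans treat the queried bit as inclusive:

```rust
#[cfg(test)]
mod tests {
    use super::BitField;

    #[test]
    fn bitfield_basics() {
        // All bits start set.
        let mut bf = BitField::new(128);
        bf.clear(64);
        assert!(!bf.is_set(64));
        // Both scans include the queried bit itself, so a set bit is its own
        // successor/predecessor, while the cleared bit 64 is skipped over.
        assert_eq!(bf.successor(63), 63);
        assert_eq!(bf.successor(64), 65);
        assert_eq!(bf.predecessor(64), 63);
        assert_eq!(bf.predecessor(65), 65);
    }
}
```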
8 changes: 8 additions & 0 deletions tokenizers/src/models/backtracking_bpe/mod.rs
@@ -0,0 +1,8 @@
mod backtracking_state;
mod bitfield;
mod model;
mod serialization;
pub mod trainer;

pub use model::*;
pub use trainer::*;