test optims
gautierdag committed Nov 15, 2023
1 parent 5e1b8d9 commit 6dac3c7
Showing 10 changed files with 194 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -58,7 +58,7 @@ jobs:
cargo test
maturin develop
pytest tests --cov=foobar --cov-report xml
cargo llvm-cov --no-run --lcov --output-path coverage.lcov
cargo llvm-cov report --lcov --output-path coverage.lcov
- uses: codecov/codecov-action@v3
with:
files: coverage.lcov,coverage.xml
3 changes: 2 additions & 1 deletion Cargo.toml
@@ -10,4 +10,5 @@ crate-type = ["cdylib"]
[dependencies]
fancy-regex = "0.12.0"
pyo3 = { version = "0.19.0", features = ["extension-module"] }
regex = "1.5.4"
rayon = "1.8.0"
regex = "1.5.4"
2 changes: 2 additions & 0 deletions benchmarks/train_bpe.py
@@ -0,0 +1,2 @@
import bpeasy
import tokenizers
1 change: 1 addition & 0 deletions python/bpeasy/__init__.py
@@ -0,0 +1 @@
from .bpeasy import *
8 changes: 8 additions & 0 deletions python/bpeasy/bpeasy.pyi
@@ -0,0 +1,8 @@
from typing import Iterator

def train_bpe(
iterator: Iterator[str],
python_regex: str,
max_token_length: int,
vocab_size: int,
) -> dict[bytes, int]: ...
Empty file added python/bpeasy/py.typed
9 changes: 9 additions & 0 deletions python/bpeasy/utils.py
@@ -0,0 +1,9 @@
import json


def load_json(path):
with open(path, "r") as f:
return json.load(f)


def get
99 changes: 99 additions & 0 deletions src/iterator.rs
@@ -0,0 +1,99 @@
// Original source: https://github.com/huggingface/tokenizers/blob/e3bcef288b5309a19556c02a16a9c58a52197b76/bindings/python/src/utils/iterators.rs
use pyo3::prelude::*;
use pyo3::AsPyPointer;
use std::collections::VecDeque;

/// A buffered iterator that takes care of locking the GIL only when needed.
/// The `PyIterator` provided by PyO3 keeps a Python GIL token all along
/// and thus doesn't allow us to release the GIL to allow having other threads.
///
/// This iterator serves two purposes:
/// - First, as opposed to the `pyo3::PyIterator`, it is Send and can easily be parallelized
/// - Second, this let us release the GIL between two refills of the buffer, allowing other
/// Python threads to work
pub struct PyBufferedIterator<T, F> {
iter: Option<Py<PyAny>>,
converter: F,
buffer: VecDeque<PyResult<T>>,
size: usize,
}

impl<T, F, I> PyBufferedIterator<T, F>
where
F: Fn(&PyAny) -> I,
I: IntoIterator<Item = PyResult<T>>,
{
/// Create a new PyBufferedIterator using the provided Python object.
/// This object must implement the Python Iterator Protocol, and an error will
/// be return if the contract is not respected.
///
/// The `converter` provides a way to convert each item in the iterator into
/// something that doesn't embed a 'py token and thus allows the GIL to be released
///
/// The `buffer_size` represents the number of items that we buffer before we
/// need to acquire the GIL again.
pub fn new(iter: &PyAny, converter: F, buffer_size: usize) -> PyResult<Self> {
let py = iter.py();
let iter: Py<PyAny> = unsafe {
py.from_borrowed_ptr_or_err::<PyAny>(pyo3::ffi::PyObject_GetIter(iter.as_ptr()))?
.to_object(py)
};

Ok(Self {
iter: Some(iter),
converter,
buffer: VecDeque::with_capacity(buffer_size),
size: buffer_size,
})
}

/// Refill the buffer, and set `self.iter` as `None` if nothing more to get
fn refill(&mut self) -> PyResult<()> {
if self.iter.is_none() {
return Ok(());
}
Python::with_gil(|py| loop {
if self.buffer.len() >= self.size {
return Ok(());
}
match unsafe {
py.from_owned_ptr_or_opt::<PyAny>(pyo3::ffi::PyIter_Next(
self.iter.as_ref().unwrap().as_ref(py).as_ptr(),
))
} {
Some(obj) => self.buffer.extend((self.converter)(obj)),
None => {
if PyErr::occurred(py) {
return Err(PyErr::fetch(py));
} else {
self.iter = None;
}
}
};
if self.iter.is_none() {
return Ok(());
}
})
}
}

impl<T, F, I> Iterator for PyBufferedIterator<T, F>
where
F: Fn(&PyAny) -> I,
I: IntoIterator<Item = PyResult<T>>,
{
type Item = PyResult<T>;

fn next(&mut self) -> Option<Self::Item> {
if !self.buffer.is_empty() {
self.buffer.pop_front()
} else if self.iter.is_some() {
if let Err(e) = self.refill() {
return Some(Err(e));
}
self.next()
} else {
None
}
}
}
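
A minimal usage sketch of PyBufferedIterator, not taken from this commit: the helper name collect_strings, the converter closure, and the 256-item buffer size are illustrative assumptions. Each refill acquires the GIL once, converts a batch of items into owned Strings, and releases it again, so the collected data can later be handed to rayon worker threads.

use pyo3::prelude::*;

// Drain a Python iterator of strings into owned Rust Strings.
// Nothing in the buffer borrows a `'py` token, so the GIL is only held
// while `refill` pulls the next batch of items.
fn collect_strings(py_iter: &PyAny) -> PyResult<Vec<String>> {
    let buffered = PyBufferedIterator::new(
        py_iter,
        // convert each yielded object into an owned String
        |item| std::iter::once(item.extract::<String>()),
        256, // arbitrary buffer size: items converted per GIL acquisition
    )?;
    // Iterator<Item = PyResult<String>> collects into PyResult<Vec<String>>
    buffered.collect()
}
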
67 changes: 62 additions & 5 deletions src/lib.rs
@@ -2,7 +2,9 @@ use fancy_regex::Regex;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyIterator, PyString};
use rayon::prelude::*;
use std::collections::HashMap;
use std::sync::Mutex;

fn tokenize(text: &str, pattern: &str) -> Vec<Vec<Vec<u8>>> {
let regex = Regex::new(pattern);
@@ -44,6 +46,8 @@ fn get_most_frequent_pair(
Return the most frequent pair of bytes
*/

// NOTE: a rayon + Mutex variant of the counting loop below is included as a commented-out block, but the serial loop is what currently runs
let mut pair_freqs: HashMap<(Vec<u8>, Vec<u8>), u128> = HashMap::new();

// Calculate frequencies for each pair of bytes in all sentences and words
@@ -58,13 +62,36 @@ }
}
}
}
// println!("{:?}", pair_freqs);

// let pair_freqs: Mutex<HashMap<(Vec<u8>, Vec<u8>), u128>> = Mutex::new(HashMap::new());
// tokenized_bytes.par_iter().for_each(|sentence| {
// let mut local_freqs = HashMap::new();
// for word in sentence.windows(2) {
// if word[0].len() + word[1].len() > max_token_length {
// continue;
// }
// if let [a, b] = word {
// *local_freqs.entry((a.to_vec(), b.to_vec())).or_insert(0) += 1;
// }
// }

// let mut global_freqs = pair_freqs.lock().unwrap();
// for (pair, count) in local_freqs {
// *global_freqs.entry(pair).or_insert(0) += count;
// }
// });
// let pair_freqs = pair_freqs.into_inner().unwrap();
let most_frequent_pair = pair_freqs.iter().max_by_key(|&(_, count)| count);
println!("Most frequent pair: {:?}", most_frequent_pair);
if most_frequent_pair.is_none() {
return None;
}
let ((ref left, ref right), _count) = most_frequent_pair.unwrap();

println!(
"Most frequent pair: {:?} and count {}",
(left, right),
_count
);
Some((left.clone(), right.clone()))
}
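
As an alternative to the Mutex-based variant commented out above, the pair counts can be accumulated without a lock by using rayon's fold/reduce: each worker folds counts into its own HashMap, and the per-thread maps are merged afterwards. This is only a sketch of that idea, not code shipped in this commit; the helper name count_pairs is an assumption.

use rayon::prelude::*;
use std::collections::HashMap;

// Count adjacent byte-pair frequencies in parallel, skipping pairs whose
// merged length would exceed max_token_length.
fn count_pairs(
    tokenized_bytes: &[Vec<Vec<u8>>],
    max_token_length: usize,
) -> HashMap<(Vec<u8>, Vec<u8>), u128> {
    tokenized_bytes
        .par_iter()
        .fold(HashMap::new, |mut local, sentence| {
            for word in sentence.windows(2) {
                if word[0].len() + word[1].len() > max_token_length {
                    continue;
                }
                *local
                    .entry((word[0].clone(), word[1].clone()))
                    .or_insert(0u128) += 1;
            }
            local
        })
        // merge the per-thread maps into one
        .reduce(HashMap::new, |mut global, local| {
            for (pair, count) in local {
                *global.entry(pair).or_insert(0) += count;
            }
            global
        })
}
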

@@ -131,14 +158,17 @@ fn build_bpe_vocab(
#[pyfunction]
fn train_bpe(
py: Python,
iterator: PyObject,
iterator: &PyIterator,
python_regex: &PyString,
max_token_length: usize,
vocab_size: usize,
) -> PyResult<PyObject> {
let iterator = PyIterator::from_object(py, &iterator)?;
let regex = python_regex.to_str()?;

println!("STARTING BPEasy training");
let num_threads = rayon::current_num_threads();
println!("Number of threads: {}", num_threads);

// validate inputs
if max_token_length < 2 {
return Err(exceptions::PyValueError::new_err(
@@ -154,8 +184,34 @@ fn train_bpe(
return Err(exceptions::PyValueError::new_err("regex cannot be empty"));
}

let mut tokenized_bytes: Vec<Vec<Vec<u8>>> = Vec::new();
// let mut tokenized_bytes: Vec<Vec<Vec<u8>>> = Vec::new();
// let tokenized_bytes = Mutex::new(Vec::new());

// Extract strings from Python iterator and store them in a Rust Vec for parallel processing
// let strings: Vec<&str> = iterator
// .filter_map(|item_result| {
// item_result.ok().and_then(|item| {
// item.extract::<&PyString>()
// .ok()
// .and_then(|py_string| py_string.to_str().ok())
// })
// })
// .collect();

// // split all text into tokens
// strings.par_iter().for_each(|text| {
// if !text.is_empty() {
// // println!("Text: {:?}", text);
// let tokens_bytes = tokenize(text, regex);
// // Lock the mutex and extend the vector
// let mut tokenized_bytes_lock = tokenized_bytes.lock().unwrap();
// tokenized_bytes_lock.extend(tokens_bytes);
// }
// });

// let tokenized_bytes = tokenized_bytes.into_inner().unwrap();

let mut tokenized_bytes: Vec<Vec<Vec<u8>>> = Vec::new();
// split all text into tokens
for item in iterator {
let item: &PyString = item?.extract()?;
Expand All @@ -167,6 +223,7 @@ fn train_bpe(
tokenized_bytes.extend(tokens_bytes);
}

println!("Done tokenizing");
let bpe_vocab = build_bpe_vocab(tokenized_bytes, max_token_length, vocab_size);
let python_dict_out = PyDict::new(py);

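
In the same spirit, the commented-out par_iter block above pushes the tokenized output through a Mutex; a lock-free shape is to let flat_map/collect assemble the final Vec once the Python strings have been materialized. This is a sketch under that assumption, not the committed code path; the helper name tokenize_all is illustrative, and it reuses the existing tokenize function.

use rayon::prelude::*;

// Tokenize every string in parallel; rayon collects the per-string results
// into one Vec, so no shared Mutex is needed.
fn tokenize_all(strings: &[String], pattern: &str) -> Vec<Vec<Vec<u8>>> {
    strings
        .par_iter()
        .filter(|text| !text.is_empty())
        .flat_map(|text| tokenize(text, pattern))
        .collect()
}
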
20 changes: 10 additions & 10 deletions tests/test_train_bpe.py
@@ -1,12 +1,12 @@
import bpeasy
from bpeasy import train_bpe


def test_train_bpe_vocab_size():
vocab_size = 10
max_token_length = 4
regex = r"([^\s]+)|(\s+)"
vocab = bpeasy.train_bpe(
["This is a test", "this is another test", "good tests"],
vocab = train_bpe(
iter(["This is a test", "this is another test", "good tests"]),
regex,
max_token_length,
vocab_size,
@@ -18,17 +18,17 @@ def test_train_bpe_max_token_length():
vocab_size = 5
max_token_length = 2
regex = r"([^\s]+)|(\s+)"
vocab = bpeasy.train_bpe(
["This is a test", "this is another test", "good tests"],
vocab = train_bpe(
iter(["This is a test", "this is another test", "good tests"]),
regex,
max_token_length,
vocab_size,
)
for token in vocab:
assert len(token) <= max_token_length
max_token_length = 3
vocab = bpeasy.train_bpe(
["This is a test", "this is another test", "good tests"],
vocab = train_bpe(
iter(["This is a test", "this is another test", "good tests"]),
regex,
max_token_length,
vocab_size,
@@ -38,11 +38,11 @@ def test_train_bpe_max_token_length():


def test_train_bpe_gpt_regex():
vocab_size = 20
vocab_size = 30
max_token_length = 128
regex = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
vocab = bpeasy.train_bpe(
["We've got a test", "We've got another test", "this is a good tests"],
vocab = train_bpe(
iter(["We've got a test", "We've got good test", "this is a good tests"]),
regex,
max_token_length,
vocab_size,
