Commit

add typing stub, refactor tokenize
gautierdag committed Nov 14, 2023
1 parent eb481ac commit a147225
Showing 7 changed files with 129 additions and 37 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
@@ -8,6 +8,7 @@ name = "bpeasy"
 crate-type = ["cdylib"]

 [dependencies]
+fancy-regex = "0.12.0"
 pyo3 = { version = "0.19.0", features = ["extension-module"] }
 regex = "1.5.4"

8 changes: 8 additions & 0 deletions bpeasy.pyi
@@ -0,0 +1,8 @@
+from typing import Iterator
+
+def train_bpe(
+    iterator: Iterator[str],
+    python_regex: str,
+    max_token_length: int,
+    vocab_size: int,
+) -> dict[bytes, int]: ...
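
A minimal sketch of calling the stubbed signature from Python, assuming the built extension is importable as bpeasy; the corpus and whitespace pattern here are illustrative only, and the vocab_size behaviour (entries learned on top of the 255 single-byte base tokens) follows the tests added below.

import bpeasy

# Illustrative inputs: any Iterator[str] works, a plain list is the simplest case.
corpus = ["hello world", "hello there world"]
pattern = r"([^\s]+)|(\s+)"  # same whitespace split used in the tests below

# Returns dict[bytes, int]: token bytes -> token id.
vocab = bpeasy.train_bpe(corpus, pattern, 4, 10)

print(len(vocab))  # expected: 255 base byte tokens plus up to 10 learned merges
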
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -13,4 +13,4 @@ classifiers = [
 dynamic = ["version"]

 [tool.maturin]
-features = ["pyo3/extension-module"]
+features = ["pyo3/extension-module"]
3 changes: 3 additions & 0 deletions requirements.txt
@@ -0,0 +1,3 @@
+pytest>=7.1.2
+pytest-cov>=3.0.0
+maturin>=0.12.14
84 changes: 62 additions & 22 deletions src/lib.rs
@@ -1,25 +1,28 @@
+use fancy_regex::Regex;
+use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::{PyBytes, PyDict, PyIterator, PyString};
-extern crate regex;
-use regex::Regex;
 use std::collections::HashMap;

-fn tokenize(text: &str, regex: &str) -> Vec<String> {
-    // regex splits
-    let re = Regex::new(regex).unwrap();
-    re.find_iter(text)
-        .map(|mat| mat.as_str().to_string())
-        .collect()
-}
+fn tokenize(text: &str, pattern: &str) -> Vec<Vec<Vec<u8>>> {
+    let regex = Regex::new(pattern);

-fn convert_to_tokenized_bytes(tokenized_text: Vec<String>) -> Vec<Vec<Vec<u8>>> {
     let mut tokenized_bytes: Vec<Vec<Vec<u8>>> = Vec::new();
-    for token in tokenized_text {
-        let mut tokenized_byte: Vec<Vec<u8>> = Vec::new();
-        for byte in token.bytes() {
-            tokenized_byte.push(vec![byte]);
+
+    for match_result in regex.expect(pattern).find_iter(text) {
+        match match_result {
+            Ok(token) => {
+                let mut tokenized_byte: Vec<Vec<u8>> = Vec::new();
+                for byte in token.as_str().bytes() {
+                    tokenized_byte.push(vec![byte]);
+                }
+                tokenized_bytes.push(tokenized_byte);
+            }
+            Err(e) => {
+                println!("Error: {:?}", e);
+                break;
+            }
         }
-        tokenized_bytes.push(tokenized_byte);
     }
     tokenized_bytes
 }
@@ -134,14 +137,32 @@ fn train_bpe(
     let iterator = PyIterator::from_object(py, &iterator)?;
     let regex = python_regex.to_str()?;

+    // validate inputs
+    if max_token_length < 2 {
+        return Err(exceptions::PyValueError::new_err(
+            "max_token_length must be greater than 1",
+        ));
+    }
+    if vocab_size < 1 {
+        return Err(exceptions::PyValueError::new_err(
+            "vocab_size must be greater than 0",
+        ));
+    }
+    if regex.is_empty() {
+        return Err(exceptions::PyValueError::new_err("regex cannot be empty"));
+    }
+
     let mut tokenized_bytes: Vec<Vec<Vec<u8>>> = Vec::new();

     // split all text into tokens
     for item in iterator {
         let item: &PyString = item?.extract()?;
         let text = item.to_str()?;
-        let tokens = tokenize(text, regex);
-        let tokens_bytes = convert_to_tokenized_bytes(tokens);
+        if text.is_empty() {
+            continue;
+        }
+
+        let tokens_bytes = tokenize(text, regex);
         tokenized_bytes.extend(tokens_bytes);
     }

@@ -173,19 +194,32 @@ mod tests {

     #[test]
     fn test_tokenize() {
-        let text = "Your text data here";
+        let text = "a b c";
         let regex = r"([^\s]+)|(\s+)";
         let tokens = tokenize(text, regex);
-        assert_eq!(tokens, vec!["Your", " ", "text", " ", "data", " ", "here"]);
+        // assert no error
+        // assert!(tokens.is_ok());
+
+        assert_eq!(
+            tokens,
+            vec![
+                vec![vec![97]],
+                vec![vec![32]],
+                vec![vec![98]],
+                vec![vec![32]],
+                vec![vec![99]]
+            ]
+        );
     }

     #[test]
     fn test_all() {
         let text: &str = "\tYou hear £ £ £ here";
         let regex = r"([^\s]+)|(\s+)";
-        let tokens = tokenize(text, regex);
-        println!("{:?}", tokens);
-        let tokenized_bytes = convert_to_tokenized_bytes(tokens);
+        // let tokens = tokenize(text, regex);
+        // println!("{:?}", tokens);
+        // let tokenized_bytes = convert_to_tokenized_bytes(tokens);
+        let tokenized_bytes = tokenize(text, regex);
         println!("{:?}", tokenized_bytes);

         let vocab_size = 10;
@@ -194,4 +228,10 @@
         println!("{:?}", bpe_vocab);
         // Output or use the encoded text
     }
+
+    #[test]
+    fn test_initialize_vocab_bytes() {
+        let vocab = initialize_vocab_bytes();
+        assert_eq!(vocab.len(), 255);
+    }
 }
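
The argument checks added above surface on the Python side as ValueError (pyo3's PyValueError). A hedged sketch of exercising them, assuming pytest per requirements.txt; the test name and corpus are hypothetical, and the exact error messages are not asserted.

import pytest

import bpeasy


def test_train_bpe_rejects_invalid_arguments():
    corpus = ["some text"]
    pattern = r"([^\s]+)|(\s+)"

    # max_token_length < 2 is rejected
    with pytest.raises(ValueError):
        bpeasy.train_bpe(corpus, pattern, 1, 10)

    # vocab_size < 1 is rejected
    with pytest.raises(ValueError):
        bpeasy.train_bpe(corpus, pattern, 4, 0)

    # an empty regex is rejected
    with pytest.raises(ValueError):
        bpeasy.train_bpe(corpus, "", 4, 10)
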
54 changes: 54 additions & 0 deletions tests/test_train_bpe.py
@@ -0,0 +1,54 @@
+import bpeasy
+
+
+def test_train_bpe_vocab_size():
+    vocab_size = 10
+    max_token_length = 4
+    regex = r"([^\s]+)|(\s+)"
+    vocab = bpeasy.train_bpe(
+        ["This is a test", "this is another test", "good tests"],
+        regex,
+        max_token_length,
+        vocab_size,
+    )
+    assert len(vocab) == vocab_size + 255
+
+
+def test_train_bpe_max_token_length():
+    vocab_size = 5
+    max_token_length = 2
+    regex = r"([^\s]+)|(\s+)"
+    vocab = bpeasy.train_bpe(
+        ["This is a test", "this is another test", "good tests"],
+        regex,
+        max_token_length,
+        vocab_size,
+    )
+    for token in vocab:
+        assert len(token) <= max_token_length
+    max_token_length = 3
+    vocab = bpeasy.train_bpe(
+        ["This is a test", "this is another test", "good tests"],
+        regex,
+        max_token_length,
+        vocab_size,
+    )
+    for token in vocab:
+        assert len(token) <= max_token_length
+
+
+def test_train_bpe_gpt_regex():
+    vocab_size = 20
+    max_token_length = 128
+    regex = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+    vocab = bpeasy.train_bpe(
+        ["We've got a test", "We've got another test", "this is a good tests"],
+        regex,
+        max_token_length,
+        vocab_size,
+    )
+    for token in vocab:
+        assert len(token) <= max_token_length
+
+    assert b" go" in vocab.keys()
+    assert b"'ve" in vocab.keys()
14 changes: 0 additions & 14 deletions tests/train_test.py

This file was deleted.
