Commit: maturin working

gautierdag committed Nov 13, 2023
1 parent 669cfd1 commit ca6acc0

Showing 6 changed files with 814 additions and 2 deletions.
7 changes: 6 additions & 1 deletion Cargo.toml
@@ -3,11 +3,16 @@ name = "bpeasy"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "bpeasy"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.19.0", features = ["extension-module"] }
regex = "1.5.4"

[[bin]]
name = "bpeasy"
path = "main.rs"


2 changes: 1 addition & 1 deletion README.md
@@ -1 +1 @@
# bpeasy
# bpeasy
16 changes: 16 additions & 0 deletions pyproject.toml
@@ -0,0 +1,16 @@
[build-system]
requires = ["maturin>=1.3,<2.0"]
build-backend = "maturin"

[project]
name = "bpeasy"
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Rust",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dynamic = ["version"]

[tool.maturin]
features = ["pyo3/extension-module"]
16 changes: 16 additions & 0 deletions src/lib.rs
@@ -0,0 +1,16 @@
use pyo3::prelude::*;

/// Formats the sum of two numbers as a string.
#[pyfunction]
fn sum_as_string(a: usize, b: usize) -> PyResult<String> {
Ok((a + b).to_string())
}

/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
#[pymodule]
fn bpeasy(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(sum_as_string, m)?)?;
Ok(())
}
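Once the extension has been compiled by maturin (the build step is not shown in this commit), the #[pymodule] above should be importable from Python under the name bpeasy, matching lib.name in Cargo.toml. A minimal usage sketch, assuming the module built and installed successfully:

import bpeasy

# sum_as_string is the only function the module exports at this point
assert bpeasy.sum_as_string(2, 3) == "5"
print(bpeasy.sum_as_string(40, 2))  # "42"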
167 changes: 167 additions & 0 deletions src/main.rs
@@ -0,0 +1,167 @@
extern crate regex;
use pyo3::prelude::*;
use regex::Regex;
use std::collections::HashMap;

/// Formats the sum of two numbers as a string.
#[pyfunction]
fn sum_as_string(a: usize, b: usize) -> PyResult<String> {
Ok((a + b).to_string())
}

/// A Python module implemented in Rust.
#[pymodule]
fn bpeasy(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(sum_as_string, m)?)?;
Ok(())
}

fn tokenize(text: &str) -> Vec<String> {
    // split the text into runs of non-whitespace and runs of whitespace
    let re = Regex::new(r"([^\s]+)|(\s+)").unwrap();
re.find_iter(text)
.map(|mat| mat.as_str().to_string())
.collect()
}

fn convert_to_tokenized_bytes(tokenized_text: Vec<String>) -> Vec<Vec<Vec<u8>>> {
let mut tokenized_bytes: Vec<Vec<Vec<u8>>> = Vec::new();
for token in tokenized_text {
let mut tokenized_byte: Vec<Vec<u8>> = Vec::new();
for byte in token.bytes() {
tokenized_byte.push(vec![byte]);
}
tokenized_bytes.push(tokenized_byte);
}
tokenized_bytes
}

fn initialize_vocab_bytes() -> HashMap<Vec<u8>, u64> {
let mut vocab: HashMap<Vec<u8>, u64> = HashMap::new();
    // cover all 256 possible byte values; `0..255` would miss byte 255
    for i in 0..=255u8 {
vocab.insert(vec![i], i as u64);
}
vocab
}

fn get_most_frequent_pair(tokenized_bytes: &mut Vec<Vec<Vec<u8>>>) -> Option<(Vec<u8>, Vec<u8>)> {
/*
Calculate frequencies for each pair of bytes in all sentences and words
Return the most frequent pair of bytes
*/

let mut pair_freqs: HashMap<(Vec<u8>, Vec<u8>), u128> = HashMap::new();

// Calculate frequencies for each pair of bytes in all sentences and words
for sentence in tokenized_bytes {
for word in sentence.windows(2) {
if let [a, b] = word {
*pair_freqs.entry((a.to_vec(), b.to_vec())).or_insert(0) += 1;
}
}
}
// println!("{:?}", pair_freqs);
let most_frequent_pair = pair_freqs.iter().max_by_key(|&(_, count)| count);
println!("Most frequent pair: {:?}", most_frequent_pair);
if most_frequent_pair.is_none() {
return None;
}
let ((ref left, ref right), _count) = most_frequent_pair.unwrap();
Some((left.clone(), right.clone()))
}

fn merge_frequent_pair(tokenized_bytes: &mut Vec<Vec<Vec<u8>>>, left: Vec<u8>, right: Vec<u8>) {
// Merge the most frequent pair in all sentences and words
for sentence in tokenized_bytes.iter_mut() {
let mut i = 0;
        while i + 1 < sentence.len() {
            // Check if the current and next token form the most frequent pair
            if sentence[i] == left && sentence[i + 1] == right {
// Merge the pair and replace the first element with the merged pair
let merged = [&sentence[i][..], &sentence[i + 1][..]].concat();
sentence[i] = merged;
// Remove the second element of the pair
sentence.remove(i + 1);
// Do not increment i, as we want to check the next pair starting from the current position
} else {
i += 1; // Move to the next token
}
}
}
}

// fn print_vocab_bytes(vocab: &HashMap<Vec<u8>, u64>) {
// // sort by value
// let mut sorted_vocab: Vec<_> = vocab.iter().collect();
// sorted_vocab.sort_by(|a, b| a.1.cmp(b.1));
// for (key, value) in sorted_vocab {
// // try to convert to string
// let key_str = String::from_utf8_lossy(key);
// println!("{:?}: {}", key_str, value);
// }
// }

fn build_bpe_vocab(
mut tokenized_bytes: Vec<Vec<Vec<u8>>>,
vocab_size: usize,
) -> HashMap<Vec<u8>, u64> {
let mut vocab: HashMap<Vec<u8>, u64> = initialize_vocab_bytes();

println!("{:?}", vocab);

let mut num_token_added = 0;
    // `vocab_size` counts the merge tokens added on top of the initial 256 byte tokens
    while num_token_added < vocab_size {
println!("Iteration: {}", num_token_added);

let most_frequent_pair = get_most_frequent_pair(&mut tokenized_bytes);
if most_frequent_pair.is_none() {
break;
}
let (left, right) = most_frequent_pair.unwrap();

// Merge the most frequent pair in all sentences and words
merge_frequent_pair(&mut tokenized_bytes, left.clone(), right.clone());

let mut token = left.clone(); // Clone the first token
token.extend(right); // Extend with the second token
        // `token` now holds the merged pair
println!("Combined token: {:?}", token);

// combine pair into a single token
let token_str = String::from_utf8_lossy(&token);
println!("Token added: {:?}", token_str);
vocab.insert(token, vocab.len() as u64);

num_token_added += 1;
}
// print_vocab_bytes(&vocab);
vocab
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_tokenize() {
let text = "Your text data here";
let tokens = tokenize(text);
assert_eq!(tokens, vec!["Your", " ", "text", " ", "data", " ", "here"]);
}

#[test]
fn test_all() {
let text: &str = "\tYou hear £ £ £ here";

let tokens = tokenize(text);
println!("{:?}", tokens);
let tokenized_bytes = convert_to_tokenized_bytes(tokens);
println!("{:?}", tokenized_bytes);

let vocab_size = 10;
let bpe_vocab = build_bpe_vocab(tokenized_bytes, vocab_size);
println!("{:?}", bpe_vocab);
// Output or use the encoded text
}
}
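The training loop in build_bpe_vocab repeatedly counts adjacent byte-token pairs across all words, merges the most frequent pair everywhere it occurs, and records the concatenated bytes as a new vocabulary entry. A rough Python sketch of a single iteration, illustrative only (the Rust functions above are the actual implementation; the names and data below are made up):

from collections import Counter

# words as lists of single-byte tokens, mirroring convert_to_tokenized_bytes
words = [[b"h", b"e", b"l", b"l", b"o"], [b"h", b"e", b"l", b"p"]]

# count adjacent pairs, as in get_most_frequent_pair
pairs = Counter((w[i], w[i + 1]) for w in words for i in range(len(w) - 1))
left, right = max(pairs, key=pairs.get)  # (b"h", b"e") in this toy example, frequency 2

# merge that pair wherever it occurs, as in merge_frequent_pair
for w in words:
    i = 0
    while i + 1 < len(w):
        if w[i] == left and w[i + 1] == right:
            w[i:i + 2] = [left + right]
        else:
            i += 1

print(words)  # [[b'he', b'l', b'l', b'o'], [b'he', b'l', b'p']]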