diff --git a/benchmarks/train.py b/benchmarks/train.py
index c86eff7..af4635b 100644
--- a/benchmarks/train.py
+++ b/benchmarks/train.py
@@ -19,32 +19,19 @@
 @dataclasses.dataclass
 class TrainBPETokenizerArgs:
-    english_datasets: str = (
-        "/Users/gautier/Github/tokenizer-benchmarks/data/english/test"
-    )
-    code_datasets: str = "/Users/gautier/Github/tokenizer-benchmarks/data/code/test"
-    multilingual_datasets: str = (
-        "/Users/gautier/Github/tokenizer-benchmarks/data/multilingual/test"
-    )
+    datasets: str = "./benchmarks/data"
     num_characters: int = 1000
     vocab_size: int = 1024
     max_sentencepiece_length: int = 32
     normalization_rule_name: str = "gpt"
-    code_percentage: float = 0.4
-    multilingual_percentage: float = 0.3

     def __post_init__(self):
-        datasets = (
-            self.english_datasets.split(",")
-            + self.code_datasets.split(",")
-            + self.multilingual_datasets.split(",")
-        )
+        datasets = self.datasets.split(",")
         for ckpt in datasets:
             checkpoint_dir = Path(ckpt)
             assert checkpoint_dir.is_dir(), checkpoint_dir

-        assert self.code_percentage + self.multilingual_percentage <= 1
         assert self.normalization_rule_name in [
             "gpt",
             "gpt-num2",
@@ -103,24 +90,9 @@ def jsonl_content_iterator(

 def mix_jsonl_content_iterator(args: TrainBPETokenizerArgs):
     datasets = []
-    code_datasets = args.code_datasets.split(",")
-    mp_datasets = args.multilingual_datasets.split(",")
-    en_datasets = args.english_datasets.split(",")
-    for dataset in code_datasets:
-        if args.code_percentage > 0:
-            datasets.append((dataset, args.code_percentage / len(code_datasets)))
-    for dataset in mp_datasets:
-        if args.multilingual_percentage > 0:
-            datasets.append((dataset, args.multilingual_percentage / len(mp_datasets)))
-    for dataset in en_datasets:
-        if (1 - args.code_percentage - args.multilingual_percentage) > 0:
-            datasets.append(
-                (
-                    dataset,
-                    (1 - args.code_percentage - args.multilingual_percentage)
-                    / len(en_datasets),
-                )
-            )
+    num_datasets = len(args.datasets.split(","))
+    for dataset in args.datasets.split(","):
+        datasets.append((dataset, 1.0 / num_datasets))

     # Create iterators
     iterators = []
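For context, the simplified mixing above gives every entry of the comma-separated `datasets` string an equal weight, replacing the old per-category `code_percentage`/`multilingual_percentage` split. A minimal Python sketch of that behaviour (the paths are illustrative, and the assumption that each weight is applied to the `num_characters` budget downstream is mine, not part of the patch):

```python
# Illustrative only: equal-weight dataset mixing after this change.
num_characters = 1000  # default budget from TrainBPETokenizerArgs
paths = "./benchmarks/data/en,./benchmarks/data/code".split(",")
weights = [(path, 1.0 / len(paths)) for path in paths]
budget = {path: int(weight * num_characters) for path, weight in weights}
print(budget)  # {'./benchmarks/data/en': 500, './benchmarks/data/code': 500}
```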
regex pattern")); + let (tokens, counts): (Vec, Vec) = strings + .par_iter() + .filter(|text| !text.is_empty()) + .flat_map(|&text| pretokenize(text, Arc::clone(®ex))) + .fold( + || HashMap::new(), + |mut acc, token| { + *acc.entry(token).or_insert(0) += 1; + acc + }, + ) + .reduce( + || HashMap::new(), + |mut a, b| { + for (token, count) in b { + *a.entry(token).or_insert(0) += count; + } + a + }, + ) + .into_iter() + .unzip(); + + let sentences: Vec = tokens.into_iter().map(Sentence::from_str).collect(); + (sentences, counts) } fn initialize_vocab_bytes(vocab_size: usize) -> (HashMap, u32>, Vec>) { @@ -159,7 +186,8 @@ fn initialize_vocab_bytes(vocab_size: usize) -> (HashMap, u32>, Vec, + tokenized_sentences: &[Sentence], + base_counts: &[u64], ) -> (HashMap, HashMap>) { // Calculate frequencies for each pair of bytes in all sentences and words return tokenized_sentences @@ -167,14 +195,13 @@ fn get_most_frequent_pair( .enumerate() .map(|(i, sentence)| { let mut local_pair_counts = HashMap::new(); - let mut local_pair_positions = HashMap::new(); - for word in sentence.get_symbols().windows(2) { - let current_pair: Pair = (word[0], word[1]); + let mut local_pair_positions: HashMap> = HashMap::new(); - // Initialize pair_counts for this pair if we just saw it for the first time + for window in sentence.get_symbols().windows(2) { + let current_pair: Pair = (window[0], window[1]); local_pair_counts .entry(current_pair) - .and_modify(|c| *c += 1) + .and_modify(|c| *c += base_counts[i] as i64) .or_insert(1); // Then update position @@ -215,12 +242,17 @@ fn get_most_frequent_pair( // Build vocab from most frequent pairs fn build_bpe_vocab( tokenized_sentences: Vec, + base_counts: &[u64], max_token_length: usize, vocab_size: usize, ) -> HashMap, u32> { + let time = std::time::Instant::now(); let (mut word_to_id, mut id_to_word) = initialize_vocab_bytes(vocab_size); + let (mut global_pair_counts, mut global_pair_positions) = - get_most_frequent_pair(&tokenized_sentences); + get_most_frequent_pair(&tokenized_sentences, &base_counts); + + println!("Time to get most frequent pair: {:?}", time.elapsed()); // build Priority Queue from counts and positions let mut queue: BinaryHeap = BinaryHeap::new(); @@ -279,11 +311,13 @@ fn build_bpe_vocab( .collect::>(); for ((pair, change), iw) in changes { + // adjust count to reflect sentence level count + let count = change * base_counts[iw] as i64; global_pair_counts .entry(pair) - .and_modify(|c| *c += change) - .or_insert(change); - if change > 0 { + .and_modify(|c| *c += count) + .or_insert(count); + if count > 0 { global_pair_positions .entry(pair) .and_modify(|h| { @@ -323,6 +357,7 @@ fn train_bpe( let num_threads = rayon::current_num_threads(); println!("Number of threads: {}", num_threads); + let time = std::time::Instant::now(); // validate inputs if max_token_length < 2 { return Err(exceptions::PyValueError::new_err( @@ -348,21 +383,23 @@ fn train_bpe( }) }) .collect(); - - let pretokenized_sentences: Vec = strings - .par_iter() - .filter(|text| !text.is_empty()) // Filter out empty strings - .map(|text| pretokenize(text, ®ex)) // Tokenize non-empty strings - .reduce( - || Vec::new(), - |mut acc, sentences| { - acc.extend(sentences); - acc - }, - ); + println!("Time to get strings {:?}", time.elapsed()); + println!("Number of strings: {}", strings.len()); + let (pretokenized_sentences, counts): (Vec, Vec) = + pretokenize_strings(strings, regex); + println!( + "Number of pretokenized_sentences: {}", + pretokenized_sentences.len() + ); + println!("Time 
to get pretokenize {:?}", time.elapsed()); println!("Done tokenizing"); - let bpe_vocab = build_bpe_vocab(pretokenized_sentences, max_token_length, vocab_size); + let bpe_vocab = build_bpe_vocab( + pretokenized_sentences, + &counts, + max_token_length, + vocab_size, + ); let python_dict_out = PyDict::new(py); // convert bpe_vocab to python dict for (key, value) in bpe_vocab { @@ -386,13 +423,28 @@ mod tests { #[test] fn test_all() { let text: &str = "\tYou hear £ £ £ here"; - let regex = r"([^\s]+)|(\s+)"; - let pretokenized_sentences = crate::pretokenize(text, regex); + let pattern = r"([^\s]+)|(\s+)"; + use fancy_regex::Regex; + use std::sync::Arc; + + let compiled_regex = Arc::new(Regex::new(pattern).expect("Invalid regex pattern")); + + let pretokenized_sentences = crate::pretokenize(text, compiled_regex); println!("{:?}", pretokenized_sentences); + let text_2: &str = "You hear £ £ £ here"; + + let (pretokenized_sentences, _counts) = + crate::pretokenize_strings(vec![text, text_2], pattern); + let vocab_size = 300; let max_token_length = 128; - crate::build_bpe_vocab(pretokenized_sentences, max_token_length, vocab_size); + crate::build_bpe_vocab( + pretokenized_sentences, + &_counts, + max_token_length, + vocab_size, + ); } #[test] diff --git a/train.py b/train.py index bf53907..2bc95e1 100644 --- a/train.py +++ b/train.py @@ -209,12 +209,16 @@ def train(args: TrainBPETokenizerArgs): iterator = mix_jsonl_content_iterator(args) # training the tokenizer regex = get_regex_from_normalization_rule_name(args.normalization_rule_name) + import time + time_now = time.time() vocab = bpeasy.train_bpe( iterator, regex, args.max_sentencepiece_length, args.vocab_size, ) + logging.info(f"Training took {time.time() - time_now} seconds") + name = generate_model_name(asdict(args)) bpeasy.save_vocab_to_tiktoken(