
Commit

add hashing tokenized chunks
gautierdag committed Dec 12, 2023
1 parent 7923da0 commit 2996643
Showing 3 changed files with 106 additions and 78 deletions.
38 changes: 5 additions & 33 deletions benchmarks/train.py
@@ -19,32 +19,19 @@

@dataclasses.dataclass
class TrainBPETokenizerArgs:
english_datasets: str = (
"/Users/gautier/Github/tokenizer-benchmarks/data/english/test"
)
code_datasets: str = "/Users/gautier/Github/tokenizer-benchmarks/data/code/test"
multilingual_datasets: str = (
"/Users/gautier/Github/tokenizer-benchmarks/data/multilingual/test"
)
datasets: str = "./benchmarks/data"

num_characters: int = 1000
vocab_size: int = 1024
max_sentencepiece_length: int = 32
normalization_rule_name: str = "gpt"
code_percentage: float = 0.4
multilingual_percentage: float = 0.3

def __post_init__(self):
datasets = (
self.english_datasets.split(",")
+ self.code_datasets.split(",")
+ self.multilingual_datasets.split(",")
)
datasets = self.datasets.split(",")
for ckpt in datasets:
checkpoint_dir = Path(ckpt)
assert checkpoint_dir.is_dir(), checkpoint_dir

assert self.code_percentage + self.multilingual_percentage <= 1
assert self.normalization_rule_name in [
"gpt",
"gpt-num2",
@@ -103,24 +90,9 @@ def jsonl_content_iterator(

def mix_jsonl_content_iterator(args: TrainBPETokenizerArgs):
datasets = []
code_datasets = args.code_datasets.split(",")
mp_datasets = args.multilingual_datasets.split(",")
en_datasets = args.english_datasets.split(",")
for dataset in code_datasets:
if args.code_percentage > 0:
datasets.append((dataset, args.code_percentage / len(code_datasets)))
for dataset in mp_datasets:
if args.multilingual_percentage > 0:
datasets.append((dataset, args.multilingual_percentage / len(mp_datasets)))
for dataset in en_datasets:
if (1 - args.code_percentage - args.multilingual_percentage) > 0:
datasets.append(
(
dataset,
(1 - args.code_percentage - args.multilingual_percentage)
/ len(en_datasets),
)
)
num_datasets = len(args.datasets.split(","))
for dataset in args.datasets.split(","):
datasets.append((dataset, args.code_percentage / num_datasets))

# Create iterators
iterators = []
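The simplified mixing logic above now gives every comma-separated dataset directory the same weight (args.code_percentage / num_datasets). A minimal Python sketch of how such a list of (dataset, weight) pairs can be consumed — the helper name mix_content_iterators and the sampling strategy are illustrative assumptions, not code from this repository:

    import random
    from typing import Iterator, List

    def mix_content_iterators(iterators: List[Iterator[str]], weights: List[float]) -> Iterator[str]:
        # Hypothetical helper: draw the next text from one of the dataset iterators
        # in proportion to its weight, dropping a dataset once it is exhausted.
        active = [[it, w] for it, w in zip(iterators, weights)]
        while active:
            idx = random.choices(range(len(active)), weights=[w for _, w in active], k=1)[0]
            try:
                yield next(active[idx][0])
            except StopIteration:
                active.pop(idx)

    # e.g. two datasets mixed 50/50:
    mixed = mix_content_iterators([iter(["a", "b"]), iter(["x", "y"])], [0.5, 0.5])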
142 changes: 97 additions & 45 deletions src/lib.rs
@@ -5,6 +5,7 @@ use pyo3::types::{PyBytes, PyDict, PyIterator, PyString};
use rayon::prelude::*;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::sync::Arc;

type Pair = (u32, u32);

@@ -122,30 +123,56 @@ impl Sentence {
fn get_symbols(&self) -> Vec<u32> {
self.symbols.iter().map(|s| s.c).collect()
}
}

fn pretokenize(text: &str, pattern: &str) -> Vec<Sentence> {
let regex = Regex::new(pattern);

let mut pretokenized_sentences: Vec<Sentence> = Vec::new();
fn from_str(s: String) -> Self {
let mut sentence = Sentence::new();
for byte in s.bytes() {
sentence.add(byte as u32, 1);
}
sentence
}
}

for match_result in regex.expect(pattern).find_iter(text) {
match match_result {
Ok(token) => {
let mut sentence: Sentence = Sentence::new();
for byte in token.as_str().bytes() {
// tokenized_byte.push(byte as u32);
sentence.add(byte as u32, 1);
}
pretokenized_sentences.push(sentence);
}
fn pretokenize(text: &str, regex: Arc<Regex>) -> Vec<String> {
regex
.find_iter(text)
.map(|mat| match mat {
Ok(token) => token.as_str().to_string(),
Err(e) => {
println!("Error: {:?}", e);
break;
"".to_string()
}
}
}
pretokenized_sentences
})
.collect()
}
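pretokenize now only yields the matched chunk strings; Sentence::from_str above converts each chunk to its raw UTF-8 bytes afterwards. A rough Python equivalent of the chunking step, using the simple whitespace pattern from the test at the bottom of this file (GPT-style patterns with look-around would need the third-party regex module rather than re):

    import re

    def pretokenize(text: str, pattern: str) -> list[str]:
        # Split the text into pretoken chunks; error matches are simply skipped here.
        return [m.group(0) for m in re.finditer(pattern, text)]

    pretokenize("\tYou hear £ £ £ here", r"([^\s]+)|(\s+)")
    # ['\t', 'You', ' ', 'hear', ' ', '£', ' ', '£', ' ', '£', ' ', 'here']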

fn pretokenize_strings(strings: Vec<&str>, pattern: &str) -> (Vec<Sentence>, Vec<u64>) {
let regex = Arc::new(Regex::new(pattern).expect("Invalid regex pattern"));
let (tokens, counts): (Vec<String>, Vec<u64>) = strings
.par_iter()
.filter(|text| !text.is_empty())
.flat_map(|&text| pretokenize(text, Arc::clone(&regex)))
.fold(
|| HashMap::new(),
|mut acc, token| {
*acc.entry(token).or_insert(0) += 1;
acc
},
)
.reduce(
|| HashMap::new(),
|mut a, b| {
for (token, count) in b {
*a.entry(token).or_insert(0) += count;
}
a
},
)
.into_iter()
.unzip();

let sentences: Vec<Sentence> = tokens.into_iter().map(Sentence::from_str).collect();
(sentences, counts)
}
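This function carries the change the commit title refers to: identical pretokenized chunks are hashed into a map inside a parallel fold/reduce, so each unique chunk becomes a single Sentence and its occurrence count travels along as base_counts to weight the pair statistics. A single-threaded Python sketch of the same idea (function names here are illustrative, not the crate's API):

    import re
    from collections import Counter

    def count_chunks(texts, pattern):
        # Hash identical pretokenized chunks together so each unique chunk is
        # stored once; its count plays the role of base_counts on the Rust side.
        chunk_counts = Counter()
        for text in texts:
            chunk_counts.update(m.group(0) for m in re.finditer(pattern, text))
        return chunk_counts

    def weighted_pair_counts(chunk_counts):
        # Byte-pair frequencies weighted by how often each unique chunk occurred,
        # mirroring get_most_frequent_pair(tokenized_sentences, base_counts).
        pairs = Counter()
        for chunk, count in chunk_counts.items():
            symbols = list(chunk.encode("utf-8"))
            for a, b in zip(symbols, symbols[1:]):
                pairs[(a, b)] += count
        return pairs

    chunks = count_chunks(["You hear £ £ £ here", "\tYou hear £ £ £ here"], r"([^\s]+)|(\s+)")
    weighted_pair_counts(chunks).most_common(3)

The same per-chunk count is what build_bpe_vocab multiplies into each merge's pair-count update below (change * base_counts[iw]).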

fn initialize_vocab_bytes(vocab_size: usize) -> (HashMap<Vec<u8>, u32>, Vec<Vec<u8>>) {
@@ -159,22 +186,22 @@ fn initialize_vocab_bytes(vocab_size: usize) -> (HashMap<Vec<u8>, u32>, Vec<Vec<
}

fn get_most_frequent_pair(
tokenized_sentences: &Vec<Sentence>,
tokenized_sentences: &[Sentence],
base_counts: &[u64],
) -> (HashMap<Pair, i64>, HashMap<Pair, HashSet<usize>>) {
// Calculate frequencies for each pair of bytes in all sentences and words
return tokenized_sentences
.par_iter()
.enumerate()
.map(|(i, sentence)| {
let mut local_pair_counts = HashMap::new();
let mut local_pair_positions = HashMap::new();
for word in sentence.get_symbols().windows(2) {
let current_pair: Pair = (word[0], word[1]);
let mut local_pair_positions: HashMap<Pair, HashSet<usize>> = HashMap::new();

// Initialize pair_counts for this pair if we just saw it for the first time
for window in sentence.get_symbols().windows(2) {
let current_pair: Pair = (window[0], window[1]);
local_pair_counts
.entry(current_pair)
.and_modify(|c| *c += 1)
.and_modify(|c| *c += base_counts[i] as i64)
.or_insert(1);

// Then update position
@@ -215,12 +242,17 @@
// Build vocab from most frequent pairs
fn build_bpe_vocab(
tokenized_sentences: Vec<Sentence>,
base_counts: &[u64],
max_token_length: usize,
vocab_size: usize,
) -> HashMap<Vec<u8>, u32> {
let time = std::time::Instant::now();
let (mut word_to_id, mut id_to_word) = initialize_vocab_bytes(vocab_size);

let (mut global_pair_counts, mut global_pair_positions) =
get_most_frequent_pair(&tokenized_sentences);
get_most_frequent_pair(&tokenized_sentences, &base_counts);

println!("Time to get most frequent pair: {:?}", time.elapsed());

// build Priority Queue from counts and positions
let mut queue: BinaryHeap<Merge> = BinaryHeap::new();
@@ -279,11 +311,13 @@ fn build_bpe_vocab(
.collect::<Vec<_>>();

for ((pair, change), iw) in changes {
// adjust count to reflect sentence level count
let count = change * base_counts[iw] as i64;
global_pair_counts
.entry(pair)
.and_modify(|c| *c += change)
.or_insert(change);
if change > 0 {
.and_modify(|c| *c += count)
.or_insert(count);
if count > 0 {
global_pair_positions
.entry(pair)
.and_modify(|h| {
@@ -323,6 +357,7 @@ fn train_bpe(
let num_threads = rayon::current_num_threads();
println!("Number of threads: {}", num_threads);

let time = std::time::Instant::now();
// validate inputs
if max_token_length < 2 {
return Err(exceptions::PyValueError::new_err(
@@ -348,21 +383,23 @@
})
})
.collect();

let pretokenized_sentences: Vec<Sentence> = strings
.par_iter()
.filter(|text| !text.is_empty()) // Filter out empty strings
.map(|text| pretokenize(text, &regex)) // Tokenize non-empty strings
.reduce(
|| Vec::new(),
|mut acc, sentences| {
acc.extend(sentences);
acc
},
);
println!("Time to get strings {:?}", time.elapsed());
println!("Number of strings: {}", strings.len());
let (pretokenized_sentences, counts): (Vec<Sentence>, Vec<u64>) =
pretokenize_strings(strings, regex);
println!(
"Number of pretokenized_sentences: {}",
pretokenized_sentences.len()
);
println!("Time to get pretokenize {:?}", time.elapsed());

println!("Done tokenizing");
let bpe_vocab = build_bpe_vocab(pretokenized_sentences, max_token_length, vocab_size);
let bpe_vocab = build_bpe_vocab(
pretokenized_sentences,
&counts,
max_token_length,
vocab_size,
);
let python_dict_out = PyDict::new(py);
// convert bpe_vocab to python dict
for (key, value) in bpe_vocab {
Expand All @@ -386,13 +423,28 @@ mod tests {
#[test]
fn test_all() {
let text: &str = "\tYou hear £ £ £ here";
let regex = r"([^\s]+)|(\s+)";
let pretokenized_sentences = crate::pretokenize(text, regex);
let pattern = r"([^\s]+)|(\s+)";
use fancy_regex::Regex;
use std::sync::Arc;

let compiled_regex = Arc::new(Regex::new(pattern).expect("Invalid regex pattern"));

let pretokenized_sentences = crate::pretokenize(text, compiled_regex);
println!("{:?}", pretokenized_sentences);

let text_2: &str = "You hear £ £ £ here";

let (pretokenized_sentences, _counts) =
crate::pretokenize_strings(vec![text, text_2], pattern);

let vocab_size = 300;
let max_token_length = 128;
crate::build_bpe_vocab(pretokenized_sentences, max_token_length, vocab_size);
crate::build_bpe_vocab(
pretokenized_sentences,
&_counts,
max_token_length,
vocab_size,
);
}

#[test]
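Roughly the same scenario as test_all, driven through the Python binding instead. The argument order follows the call in train.py below (iterator, regex, max token length, vocab size); this is a sketch and assumes the bpeasy extension is built and importable:

    import bpeasy

    texts = ["\tYou hear £ £ £ here", "You hear £ £ £ here"]
    vocab = bpeasy.train_bpe(iter(texts), r"([^\s]+)|(\s+)", 128, 300)
    print(len(vocab))  # base byte tokens plus learned merges, at most the requested vocab size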
4 changes: 4 additions & 0 deletions train.py
@@ -209,12 +209,16 @@ def train(args: TrainBPETokenizerArgs):
iterator = mix_jsonl_content_iterator(args)
# training the tokenizer
regex = get_regex_from_normalization_rule_name(args.normalization_rule_name)
import time
time_now = time.time()
vocab = bpeasy.train_bpe(
iterator,
regex,
args.max_sentencepiece_length,
args.vocab_size,
)
logging.info(f"Training took {time.time() - time_now} seconds")

name = generate_model_name(asdict(args))

bpeasy.save_vocab_to_tiktoken(
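The timing added above uses time.time(); an equivalent sketch with a monotonic clock, reusing iterator, regex, and args from the surrounding train function, would be:

    import logging
    import time

    start = time.perf_counter()  # monotonic, unaffected by system clock adjustments
    vocab = bpeasy.train_bpe(iterator, regex, args.max_sentencepiece_length, args.vocab_size)
    logging.info(f"Training took {time.perf_counter() - start:.2f} seconds")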
