diff --git a/Cargo.toml b/Cargo.toml
index b7e72be..f426def 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -14,7 +14,8 @@ codegen-units = 1
 
 [dependencies]
 fancy-regex = "0.12.0"
+fxhash = "0.2.1"
 pyo3 = { version = "0.19.0", features = ["extension-module"] }
 rayon = "1.8.0"
 regex = "1.5.4"
-serde_json = "1.0.108"
\ No newline at end of file
+serde_json = "1.0.108"
diff --git a/src/lib.rs b/src/lib.rs
index e5f0255..a56c78c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,10 +1,12 @@
 use fancy_regex::Regex;
+use fxhash::FxHashMap as HashMap;
+use fxhash::FxHashSet as HashSet;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::{PyBytes, PyDict, PyIterator, PyString};
 use rayon::prelude::*;
 use std::cmp::Ordering;
-use std::collections::{BinaryHeap, HashMap, HashSet};
+use std::collections::BinaryHeap;
 
 type Pair = (u32, u32);
 
@@ -148,14 +150,14 @@ fn pretokenize_strings(strings: Vec<&str>, pattern: &str) -> (Vec<&str>, Vec<u64>) {
         .par_iter()
         .flat_map(|&text| pretokenize(text, &regex))
         .fold(
-            || HashMap::new(),
+            || HashMap::<&str, u64>::default(),
             |mut acc, token| {
                 *acc.entry(token).or_insert(0) += 1;
                 acc
             },
         )
         .reduce(
-            || HashMap::new(),
+            || HashMap::<&str, u64>::default(),
             |mut a, b| {
                 for (token, count) in b {
                     *a.entry(token).or_insert(0) += count;
@@ -171,7 +173,7 @@
 }
 
 fn initialize_vocab_bytes(vocab_size: usize) -> (HashMap<Vec<u8>, u32>, Vec<Vec<u8>>) {
-    let mut word_to_id: HashMap<Vec<u8>, u32> = HashMap::with_capacity(vocab_size);
+    let mut word_to_id: HashMap<Vec<u8>, u32> = HashMap::default();
     let mut id_to_word: Vec<Vec<u8>> = Vec::with_capacity(vocab_size);
     for i in 0..=255 {
         word_to_id.insert(vec![i], i as u32);
@@ -189,8 +191,8 @@ fn get_most_frequent_pair(
         .par_iter()
         .enumerate()
         .map(|(i, sentence)| {
-            let mut local_pair_counts = HashMap::new();
-            let mut local_pair_positions: HashMap<Pair, HashSet<usize>> = HashMap::new();
+            let mut local_pair_counts = HashMap::<Pair, i64>::default();
+            let mut local_pair_positions: HashMap<Pair, HashSet<usize>> = HashMap::default();
 
             for window in sentence.get_symbols().windows(2) {
                 let current_pair: Pair = (window[0], window[1]);
@@ -207,7 +209,7 @@
                         h.insert(i);
                     })
                     .or_insert_with(|| {
-                        let mut h = HashSet::new();
+                        let mut h = HashSet::<usize>::default();
                         h.insert(i);
                         h
                     });
@@ -215,7 +217,12 @@
             (local_pair_counts, local_pair_positions)
         })
         .reduce(
-            || (HashMap::new(), HashMap::new()),
+            || {
+                (
+                    HashMap::<Pair, i64>::default(),
+                    HashMap::<Pair, HashSet<usize>>::default(),
+                )
+            },
             |(mut global_pair_counts, mut global_pair_positions), (pc, wtu)| {
                 // Merge the pair counts and positions from all sentences
                 for (k, v) in pc {
@@ -318,7 +325,7 @@ fn build_bpe_vocab(
                 h.insert(iw);
             })
             .or_insert_with(|| {
-                let mut h = HashSet::new();
+                let mut h = HashSet::<usize>::default();
                 h.insert(iw);
                 h
             });