Skip to content

Commit

Permalink
use fxhash for slightly faster hashing
Browse files Browse the repository at this point in the history
  • Loading branch information
gautierdag committed Dec 19, 2023
1 parent c9177e4 commit a3bdb4b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ codegen-units = 1

[dependencies]
fancy-regex = "0.12.0"
fxhash = "0.2.1"
pyo3 = { version = "0.19.0", features = ["extension-module"] }
rayon = "1.8.0"
regex = "1.5.4"
serde_json = "1.0.108"
25 changes: 16 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use fancy_regex::Regex;
use fxhash::FxHashMap as HashMap;
use fxhash::FxHashSet as HashSet;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyIterator, PyString};
use rayon::prelude::*;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::collections::BinaryHeap;

type Pair = (u32, u32);

Expand Down Expand Up @@ -148,14 +150,14 @@ fn pretokenize_strings(strings: Vec<&str>, pattern: &str) -> (Vec<Sentence>, Vec
.par_iter()
.flat_map(|&text| pretokenize(text, &regex))
.fold(
|| HashMap::new(),
|| HashMap::<&str, u64>::default(),
|mut acc, token| {
*acc.entry(token).or_insert(0) += 1;
acc
},
)
.reduce(
|| HashMap::new(),
|| HashMap::<&str, u64>::default(),
|mut a, b| {
for (token, count) in b {
*a.entry(token).or_insert(0) += count;
Expand All @@ -171,7 +173,7 @@ fn pretokenize_strings(strings: Vec<&str>, pattern: &str) -> (Vec<Sentence>, Vec
}

fn initialize_vocab_bytes(vocab_size: usize) -> (HashMap<Vec<u8>, u32>, Vec<Vec<u8>>) {
let mut word_to_id: HashMap<Vec<u8>, u32> = HashMap::with_capacity(vocab_size);
let mut word_to_id: HashMap<Vec<u8>, u32> = HashMap::default();
let mut id_to_word: Vec<Vec<u8>> = Vec::with_capacity(vocab_size);
for i in 0..=255 {
word_to_id.insert(vec![i], i as u32);
Expand All @@ -189,8 +191,8 @@ fn get_most_frequent_pair(
.par_iter()
.enumerate()
.map(|(i, sentence)| {
let mut local_pair_counts = HashMap::new();
let mut local_pair_positions: HashMap<Pair, HashSet<usize>> = HashMap::new();
let mut local_pair_counts = HashMap::<Pair, i64>::default();
let mut local_pair_positions: HashMap<Pair, HashSet<usize>> = HashMap::default();

for window in sentence.get_symbols().windows(2) {
let current_pair: Pair = (window[0], window[1]);
Expand All @@ -207,15 +209,20 @@ fn get_most_frequent_pair(
h.insert(i);
})
.or_insert_with(|| {
let mut h = HashSet::new();
let mut h = HashSet::<usize>::default();
h.insert(i);
h
});
}
(local_pair_counts, local_pair_positions)
})
.reduce(
|| (HashMap::new(), HashMap::new()),
|| {
(
HashMap::<Pair, i64>::default(),
HashMap::<Pair, HashSet<usize>>::default(),
)
},
|(mut global_pair_counts, mut global_pair_positions), (pc, wtu)| {
// Merge the pair counts and positions from all sentences
for (k, v) in pc {
Expand Down Expand Up @@ -318,7 +325,7 @@ fn build_bpe_vocab(
h.insert(iw);
})
.or_insert_with(|| {
let mut h = HashSet::new();
let mut h = HashSet::<usize>::default();
h.insert(iw);
h
});
Expand Down

0 comments on commit a3bdb4b

Please sign in to comment.