Commit

clean
gautierdag committed Nov 23, 2023
1 parent 67cb1c3 commit 4787af6
Showing 1 changed file with 8 additions and 36 deletions.
44 changes: 8 additions & 36 deletions src/lib.rs
@@ -173,13 +173,13 @@ fn get_most_frequent_pair(
 
             // Initialize pair_counts for this pair if we just saw it for the first time
             local_pair_counts
-                .entry(current_pair.clone())
+                .entry(current_pair)
                 .and_modify(|c| *c += 1)
                 .or_insert(1);
 
             // Then update position
             local_pair_positions
-                .entry(current_pair.clone())
+                .entry(current_pair)
                 .and_modify(|h: &mut HashSet<usize>| {
                     h.insert(i);
                 })
@@ -227,16 +227,11 @@ fn build_bpe_vocab(
     global_pair_positions.drain().for_each(|(pair, pos)| {
         let count: i64 = global_pair_counts[&pair];
         if count > 0 {
-            queue.push(Merge {
-                pair: pair.clone(),
-                count,
-                pos,
-            });
+            queue.push(Merge { pair, count, pos });
         }
     });
 
-    let mut num_token_added = 0;
-    while num_token_added < vocab_size {
+    while word_to_id.len() < vocab_size {
         // check if queue is empty
         if queue.is_empty() {
             break;
@@ -262,24 +257,19 @@
         let mut word = id_to_word[left as usize].clone();
         let right_word = id_to_word[right as usize].clone();
         word.extend(right_word.iter());
 
         word_to_id.insert(word.clone(), merged_id);
         id_to_word.push(word);
-        // word_lengths.push(word_lengths[left as usize] + word_lengths[right as usize]);
-
-        num_token_added += 1;
-
         // update counts and positions for each sentence
         let changes = top
             .pos
             .par_iter()
             .flat_map(|&i| {
-                let word = &tokenized_sentences[i] as *const _ as *mut Sentence;
-                // We can merge each of these words in parallel here because each position
+                let sentence = &tokenized_sentences[i] as *const _ as *mut Sentence;
+                // We can merge each of these sentences in parallel here because each position
                 // can be there only once (HashSet). So this is safe.
                 unsafe {
-                    // let word: &mut Word = &mut (*word);
-                    (*word)
+                    (*sentence)
                         .merge(top.pair.0, top.pair.1, merged_id, max_token_length)
                         .into_iter()
                         .map(|c| (c, i))
@@ -289,7 +279,6 @@
             .collect::<Vec<_>>();
 
         for ((pair, change), iw) in changes {
-            // let count = change * counts[iw] as i32;
             global_pair_counts
                 .entry(pair)
                 .and_modify(|c| *c += change)
@@ -312,11 +301,7 @@
         global_pair_positions.drain().for_each(|(pair, pos)| {
             let count = global_pair_counts[&pair];
             if count > 0 {
-                queue.push(Merge {
-                    pair: pair.clone(),
-                    count,
-                    pos,
-                });
+                queue.push(Merge { pair, count, pos });
             }
         });
     }
@@ -397,20 +382,7 @@ fn bpeasy(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
 }
 
 #[cfg(test)]
-
 mod tests {
-
-    #[test]
-    fn test_tokenize() {
-        let text = "a b c";
-        let regex = r"([^\s]+)|(\s+)";
-        let tokens = crate::pretokenize(text, regex);
-        // assert_eq!(
-        //     tokens,
-        //     vec![vec![97], vec![32], vec![98], vec![32], vec![99]]
-        // );
-    }
-
     #[test]
     fn test_all() {
         let text: &str = "\tYou hear £ £ £ here";

0 comments on commit 4787af6
