
Another Implementation (faster and more efficient) of BPE Training Algorithm #1400

Closed
Yikai-Liao opened this issue Nov 27, 2023 · 40 comments

@Yikai-Liao

Earlier this year, I wrote a new implementation of the BPE algorithm in pure Python, which is faster than the version in Tokenizer.

I hope this implementation can help tokenizers further improve its BPE training performance.

I have written a blog post in Chinese about this implementation. I will try to translate it into English if there is any need. By the way, the code is quite short in my opinion, at merely about 400 lines.

Here is the code: https://github.com/Yikai-Liao/efficient_bpe

Implementation              user time  system time  total time  CPU
My version (Single Thread)  2.70s      0.05s        2.761s      99%
Tokenizer (Single Thread)   5.51s      1.60s        5.411s      131%
Tokenizer (Multi Threads)   8.51s      3.52s        2.849s      422%
@ArthurZucker
Collaborator

Sounds nice, waiting for the translated document and more benchmarks!

@Yikai-Liao
Author

Sounds nice, waiting for the translated document and more benchmarks!

I will finish the blog translation as soon as possible. Also, do you have any suggestions about more benchmarks? I'm not quite sure what kind of testing I should do, and the current code interface doesn't exactly align with the implementation in tokenizers (in fact, this code is from an assignment in my NLP class).

@ArthurZucker
Collaborator

Would try with different hardware, different datasets and check this benchmark

@Yikai-Liao
Author

Efficient BPE Implementation

Note that this document is not a one-to-one translation of the original.

It mainly focuses on implementation details and omits some background on the BPE algorithm.

Introduction

The BPE algorithm, an unsupervised subword segmentation algorithm, is widely used in NLP, especially in large language models represented by GPT (which uses its variant, byte-level BPE). However, most current open-source implementations of the algorithm train inefficiently on Chinese.

This is mainly because, although the input to the original BPE algorithm is the whole sentence, for Latin-script languages such as English an approximation can be made that significantly reduces the time complexity: first pre-split the text into words on spaces, and then run the BPE algorithm within each word.

It is called an approximation because it forbids merges across spaces that the original algorithm would allow, but for English it works well in most cases. Most importantly, it reduces the worst-case $O(N^2)$ time complexity of the original algorithm to $O(M^2)$, where N is the length of the sequence and M is the length of a word. (See the blog post Byte Pair Encoding and Data Structures [1] for a detailed complexity analysis.) Since $M \ll N$ in English, this approach works very well.

But this approximation does not work well for all languages. Obviously it doesn't work for languages like Chinese and Japanese, which cannot be pre-split on spaces. And even among Latin-script languages, there are some, like German, with quite long words, which undermines the precondition that makes this approximation such an effective optimization.

So here I have implemented an optimized version of the BPE algorithm without this approximation. Even as a pure Python implementation, it is substantially better in terms of speed and memory footprint than the Rust version implemented in the Hugging Face Tokenizer [2]. Note that the version implemented in Tokenizer is not the original $O(N^2)$-complexity version; it is also optimized. Here's a time comparison [3]:

Implementation              user time  system time  total time  CPU
My version (Single Thread)  2.70s      0.05s        2.761s      99%
Tokenizer (Single Thread)   5.51s      1.60s        5.411s      131%
Tokenizer (Multi Threads)   8.51s      3.52s        2.849s      422%

Training Algorithm

Starting point

The biggest problem with the original BPE algorithm is that after merging just one token pair, all the statistics need to be recounted. Especially for a language like Chinese, which has a large number of distinct symbols, each merge actually has little effect on the previous round's statistics. Moreover, each merge also requires traversing the whole corpus, so a lot of time is wasted on retrieval.

Therefore, a better approach is to modify only the data that actually needs to change in each round, so that once the modification is complete, the next round can directly continue to find and merge the best symbol pair.

The implementation in the HuggingFace Tokenizer also follows this principle. However, its efficiency can be further improved.

Optimization

In order to modify only the parts that need to change, without redundant retrieval operations, we essentially need to solve three core problems:

  • How to find the most frequent pair of symbols with complexity better than $O(N)$.
  • How to find the starting positions of a symbol pair and its neighboring symbols with complexity better than $O(N)$.
  • How to store the current state of the merge.

Here, I'll start by listing the data structures I chose:

  • A priority queue for finding the most frequent symbol pairs - pair_freq_queue
  • A dictionary of type Dict[Tuple[str, str], List[int]] [4] holding the starting positions of all current pairs - word_pair_pos
  • A uint8 array of the same length as the string representing the current merge status - seg_status
  • An additional dictionary of type Dict[Tuple[str, str], int] holding the number of valid starting positions in word_pair_pos under the current merge status - word_pair_len

Find the most frequent token pair

Using a priority queue to find the highest-frequency symbol pairs is a very natural choice. However, there are different options for what information to store in the priority queue.

  • In Tokenizer, the priority queue stores each specific occurrence of a token pair (a token pair together with one position), which makes the queue quite long. Every time we pop an entry from the queue, we merge only that one specific occurrence.

  • In my implementation, I only store the total frequency information of a token pair (e.g. "ba" and "na" forming a pair 100 times in the corpus) in the queue. We can then get all the positions of this token pair from a hash map and process them all at once (without popping new token pairs from the priority queue).

Both implementations use a similar strategy for checking whether the cached frequency of a token pair is still valid (i.e., consistent with the current statistics). Schematic code:

import heapq

# frequencies are stored negated so that Python's min-heap pops the most frequent pair first
while True:
    cached_freq, pair = heapq.heappop(pair_freq_queue)
    cur_freq = -word_pair_len[pair]
    if cached_freq == cur_freq:
        break  # the cached frequency is up to date: this is the current best pair
    else:
        # stale entry: push it back with the refreshed frequency and keep looking
        heapq.heappush(pair_freq_queue, (cur_freq, pair))

Merging Status Representation

In my implementation, I don't use a Word class to represent merged token pairs. Instead, I use an array of uint8 to represent the merging status.

  1. For tokens of length greater than 1, the start and end positions are set to the token's length, with 0 in between.
  2. For tokens of length 1, the corresponding position is set to 1 directly.
  3. seg_status is initialized as an all-ones vector.

For a string "apple", a possible merging process might be as follows:

    a  p  p  l  e
1. [1, 1, 1, 1, 1] init
2. [1, 1, 2, 2, 1] merge p + l
3. [1, 1, 3, 0, 3] merge pl + e
4. [2, 2, 3, 0, 3] merge a + p
5. [5, 0, 0, 0, 5] merge ap + ple

The effect of this representation is:

  1. Assuming we find a stored starting position 2 for the symbol p, we can read the value at this position in seg_status and determine whether it is still legal in the current merge state by checking whether this value equals the length of the symbol.

  2. Once we have a legal symbol, we can find its neighboring tokens in the current merge state in $O(1)$. For example, in the state of the second line above, we look up the neighbors of the symbol pl, which starts at position 2:

    token           start position        length
    this token      2                     seg_status[2]
    previous token  2 - seg_status[2-1]   seg_status[2-1]
    next token      2 + seg_status[2]     seg_status[2 + seg_status[2]]

It is easy to see that, under this mechanism, even if a token has length 1 so that its start and end positions coincide, the same procedure can be used to query the tokens before and after it.

Thus the data in seg_status provides two efficient auxiliary operations when merging symbol pairs (sketched in code below):

  1. $O(1)$ legality check for a stored start position
  2. $O(1)$ neighbor lookup
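To make these two operations concrete, here is a minimal Python sketch; the helper names is_valid and neighbours are mine rather than the article's, and text / seg_status are the corpus string and the uint8 array described above:

def is_valid(pos, token, seg_status):
    # A stored start position is still usable only if a token of exactly this
    # length currently starts there; stale positions fail this check.
    return seg_status[pos] == len(token)

def neighbours(pos, text, seg_status):
    # O(1) lookup of (previous token, this token, next token) around the token
    # starting at `pos`; None is returned at the corpus boundaries.
    length = seg_status[pos]
    this_tok = text[pos:pos + length]
    prev_tok = None
    if pos > 0:
        prev_len = seg_status[pos - 1]      # end cell of the previous token
        prev_tok = text[pos - prev_len:pos]
    nxt_tok = None
    nxt_start = pos + length
    if nxt_start < len(text):
        nxt_tok = text[nxt_start:nxt_start + seg_status[nxt_start]]
    return prev_tok, this_tok, nxt_tok

For example, with text = "apple" and seg_status = [1, 1, 2, 2, 1] (line 2 of the example above), neighbours(2, text, seg_status) returns ("p", "pl", "e"), and is_valid(3, "l", seg_status) returns False because position 3 is now the end cell of "pl".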

Merge symbol pairs

At initialization, the starting positions of all adjacent symbol pairs are counted and stored in a dictionary, and the length of each position list is recorded in word_pair_len. If the number of distinct pairs is V and the number of tokens is N, the space complexity of word_pair_pos is $O(N + V)$ and that of word_pair_len is $O(V)$.
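As a concrete reference, here is a minimal initialization sketch; init_state is my name for it, the corpus is assumed to be a single Python string whose initial tokens are individual characters, and details of the real implementation such as min_freq filtering are omitted:

import heapq
from collections import defaultdict

def init_state(text):
    # every character starts as a token of length 1
    # (a bytearray caps token lengths at 255)
    seg_status = bytearray([1] * len(text))
    # pair -> start positions of the first token of each occurrence
    word_pair_pos = defaultdict(list)
    for i in range(len(text) - 1):
        word_pair_pos[(text[i], text[i + 1])].append(i)
    # pair -> number of currently valid positions
    word_pair_len = defaultdict(int, {p: len(v) for p, v in word_pair_pos.items()})
    # min-heap over negated frequencies, so the most frequent pair pops first
    pair_freq_queue = [(-f, p) for p, f in word_pair_len.items()]
    heapq.heapify(pair_freq_queue)
    return seg_status, word_pair_pos, word_pair_len, pair_freq_queue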

Next, the process for each merge round is as follows (special cases are discussed later, and a code sketch follows the list):

  1. Take the stored starting positions of the selected symbol pair and filter out the stale ones using seg_status.

  2. Use seg_status to locate the neighboring tokens before and after each valid position, denoted (pre_word, word_a, word_b, nxt_word), where (word_a, word_b) is the pair currently being merged; the merged symbol is denoted word_comb.

  3. Update word_pair_len so that it always matches the current merge state:

    • word_pair_len[(pre_word, word_a)] -= 1
    • word_pair_len[(word_b, nxt_word)] -= 1
    • word_pair_len[(pre_word, word_comb)] += 1
    • word_pair_len[(word_comb, nxt_word)] += 1
  4. Update seg_status so that it reflects the current merge state. Denoting the starting position of a legal occurrence of the pair as i, the updates are as follows [5]:

    1. seg_status[i] = len(word_comb)
    2. seg_status[i + len(word_a)] = 0
    3. seg_status[i + len(word_comb) - 1] = len(word_comb)
  5. Record the starting positions of the newly created symbol pairs in a fresh dictionary new_pair_pos:

    • new_pair_pos[(pre_word, word_comb)].append(i - len(pre_word))
    • new_pair_pos[(word_comb, nxt_word)].append(i)
  6. After all positions have been processed, merge the information from new_pair_pos into word_pair_pos and word_pair_len, and push the new pairs into the priority queue. [6]
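Putting the six steps together, here is a minimal sketch of one merge round. The function name merge_round is mine; it reuses the structures from the initialization sketch, processes positions front to back, updates word_pair_len inline instead of batching it as footnote 6 describes, and leaves out min_freq handling and the special cases discussed in the next section:

import heapq
from collections import defaultdict

def merge_round(text, seg_status, word_pair_pos, word_pair_len, pair_freq_queue):
    # lazily pop pairs until the cached frequency matches the current one
    while True:
        cached_freq, pair = heapq.heappop(pair_freq_queue)
        if cached_freq == -word_pair_len[pair]:
            break
        heapq.heappush(pair_freq_queue, (-word_pair_len[pair], pair))
    if cached_freq == 0:
        return None                                    # nothing left to merge

    word_a, word_b = pair
    word_comb = word_a + word_b
    new_pair_pos = defaultdict(list)

    for i in word_pair_pos.pop(pair):
        # step 1: skip stale positions
        if seg_status[i] != len(word_a) or seg_status[i + len(word_a)] != len(word_b):
            continue
        # step 2: O(1) neighbour lookup via seg_status
        pre_word = text[i - seg_status[i - 1]:i] if i > 0 else None
        nxt_start = i + len(word_comb)
        nxt_word = (text[nxt_start:nxt_start + seg_status[nxt_start]]
                    if nxt_start < len(text) else None)
        # step 3: keep word_pair_len consistent with the new merge state
        if pre_word is not None:
            word_pair_len[(pre_word, word_a)] -= 1
            word_pair_len[(pre_word, word_comb)] += 1
        if nxt_word is not None:
            word_pair_len[(word_b, nxt_word)] -= 1
            word_pair_len[(word_comb, nxt_word)] += 1
        # step 4: update the merge status (the order matters, see footnote 5)
        seg_status[i] = len(word_comb)
        seg_status[i + len(word_a)] = 0
        seg_status[i + len(word_comb) - 1] = len(word_comb)
        # step 5: record where the newly created pairs start
        if pre_word is not None:
            new_pair_pos[(pre_word, word_comb)].append(i - len(pre_word))
        if nxt_word is not None:
            new_pair_pos[(word_comb, nxt_word)].append(i)

    word_pair_len[pair] = 0                            # all valid occurrences merged
    # step 6: fold the new positions back in and refresh the priority queue
    for new_pair, positions in new_pair_pos.items():
        word_pair_pos[new_pair].extend(positions)
        heapq.heappush(pair_freq_queue, (-word_pair_len[new_pair], new_pair))
    return word_comb

Calling init_state once and then merge_round repeatedly until it returns None (or until the target vocabulary size is reached) roughly corresponds to the training loop described above.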

Special Situations

During the merge process, two special cases arise that need to be handled:

  1. A B A B -> AB AB
  2. 0 0 0 0 -> 00 00

In essence, the second special case is just a particular instance of the first.

The modification needed to handle them is also simple:

  1. If the two preceding tokens happen to be (word_a, word_b), then the symbol pair to add is (word_comb, word_comb), not (pre_word, word_comb).

  2. If the two following tokens happen to be (word_a, word_b), then there is no need to add the next symbol pair, because it has already been added for this occurrence.

  3. In the case word_a == word_b, I prefer to match from back to front, e.g. 1 0 0 0 -> 1 0 00. Of course, you can also match from front to back, but I don't think 1 00 0 is as good as the previous result. I didn't really look into how this is handled in the HuggingFace Tokenizer. A small sketch of this back-to-front selection follows.
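For the third point, a minimal sketch of the back-to-front matching; select_back_to_front is a hypothetical helper rather than code from the article, and it assumes the currently valid start positions of the pair are sorted in ascending order:

def select_back_to_front(positions, step):
    # positions: sorted valid start positions of a pair (X, X), with step == len(X).
    # Overlapping occurrences share their middle token, so we scan from the right
    # and keep an occurrence only if it does not overlap the one already chosen to
    # its right, e.g. "1 0 0 0" -> "1 0 00" and "0 0 0 0" -> "00 00".
    chosen = []
    right_start = float("inf")    # start of the nearest occurrence kept so far
    for i in reversed(positions):
        if i + 2 * step <= right_start:
            chosen.append(i)
            right_start = i
    chosen.reverse()
    return chosen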

Memory Compression

For the above process, suppose we merge l valid positions in a round; then in general about 2 * l new positions are appended, which makes the memory footprint grow quickly. We don't actually need that much space, because many of the positions stored in word_pair_pos are stale.

To control this growth, a memory compression mechanism can be introduced. The principle is simple: every time a pair is taken from the priority queue, compute the ratio of cur_freq to the length of the stored position array in word_pair_pos. If the ratio falls below a threshold, the proportion of stale starting positions in the array is already very high, so the array can be filtered with seg_status, freeing a lot of memory. This keeps the memory footprint at a more or less stable level, as sketched below.
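A minimal sketch of this check; maybe_compact and the compact_ratio threshold are hypothetical names, and the real implementation stores the positions in array.array('I') rather than plain lists (see footnote 4):

def maybe_compact(pair, cur_freq, word_pair_pos, seg_status, compact_ratio=0.5):
    # Run when `pair` is taken from the priority queue: if only a small fraction of
    # its stored positions is still valid, rebuild the list with just the valid ones.
    positions = word_pair_pos[pair]
    if positions and cur_freq / len(positions) < compact_ratio:
        word_a, word_b = pair
        word_pair_pos[pair] = [
            i for i in positions
            if seg_status[i] == len(word_a)
            and seg_status[i + len(word_a)] == len(word_b)
        ]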

Compared to the original BPE training process, trading space for time is unavoidable if we want to reach this optimization goal, but with the memory compression mechanism the final space complexity can still be kept at $O(N)$ with a small constant. In my Python implementation, the memory consumed by data other than the string itself is about 2-3 times the space occupied by the string, which, judging from the memory usage of the background processes, is better than the single-threaded version of Tokenizer. In particular, the memory footprint of Tokenizer's BPE training increases several times over when multi-threading is enabled.

Footnotes

  1. https://guillaume-be.github.io/2021-09-16/byte_pair_encoding

  2. https://github.com/huggingface/tokenizers

  3. Timing was done with the time utility on the Linux command line, which includes file-reading time; the file reads are done in Python in both cases. The Python version is 3.9, the CPU is an i7-10875H, and the Linux kernel is 6.2.13-zen-1-zen.

  4. Note that it is possible to use a set (hash set or B-tree set) to represent the starting positions of all symbol pairs in word_pair_pos, but a set has significantly higher memory overhead than an array (especially Python's native set), which gives the algorithm a very large constant factor in its space complexity. It was so large that I couldn't complete training on a 1 GB Chinese wiki corpus with 32 GB RAM + 32 GB swap. Using array.array('I') for storage, the training can be done with about 5 GB of memory in total. The reason for not using np.array is that there is no need for vectorization, and element access on an np.array has higher overhead than on a native array.

  5. Here the order of steps 2 and 3 cannot be swapped, because when len(word_b) == 1 the two steps write to the same position, and swapping them would leave the end position at 0 instead of len(word_comb), which violates our rule.

  6. The reason we don't modify word_pair_len and word_pair_pos directly during the traversal is that it minimizes accesses to word_pair_len by avoiding repeated += 1 operations. It also makes it easier to filter the additions by min_freq.

@Yikai-Liao
Author

Yikai-Liao commented Nov 27, 2023

Would try with different hardware, different datasets and check this benchmark

If I understand correctly, I need to train BPE on big.txt with different hardware?

I tested it on my MacBook Pro (Intel i7-7700HQ (8) @ 2.80GHz, 16 GB RAM), using %%timeit in a Jupyter notebook.

vocab_size = int(1e5)
min_freq = 10

  • My Implementation (Single Thread, Pure Python): 6.56 s ± 146 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
  • Tokenizer Single Thread: 8.24 s ± 463 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
  • Tokenizer Multiple Threads: 2.78 s ± 137 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Note that how to use multiple threads in my implementation is still to be discussed.

By the way, the blog post was translated with the help of DeepL (not directly translated through DeepL), so it might not read very naturally.

@Yikai-Liao
Author

Hi, any further discussion?

@ArthurZucker
Collaborator

Hey! Sorry, but I'm a bit low on bandwidth; I need to read the blog post and take some time to check this out! 🚀

@ArthurZucker
Collaborator

Very exciting otherwise! 🤗

@github-actions github-actions bot added the Stale label Jan 1, 2024
@huggingface huggingface deleted a comment from github-actions bot Jan 3, 2024
@ArthurZucker
Collaborator

ArthurZucker commented Jan 3, 2024

Have not had the time yet sorry

@github-actions github-actions bot removed the Stale label Jan 4, 2024

github-actions bot commented Feb 4, 2024

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

@github-actions github-actions bot added the Stale label Feb 4, 2024
@github-actions github-actions bot closed this as not planned Feb 10, 2024
@Yikai-Liao
Author

@ArthurZucker Is there any other help I can offer?

@ArthurZucker
Collaborator

Actually if you could open a PR it would be amazing! 🤗

@Yikai-Liao
Author

I will give it a try

@Yikai-Liao
Author

I ran into a problem with the max token length. In my implementation, I use a vector of u8 (which uses the same amount of memory as the original corpus) to store the byte length of each token, which means the max token length has to be less than 256.

This works in most cases. Even in some very demanding situations, I think u16 is adequate (using twice the memory). But in the tokenizer's original implementation, usize is used to store max_token_length. On a 64-bit machine, this means an 8x memory overhead.

So I'm asking what I should do about this. @ArthurZucker

@ArthurZucker
Collaborator

Memory should not be that much of a problem so would keep usize. Or does it affect speed too much?

@Yikai-Liao
Author

Theoretically there is little difference in speed.

But if you consider GB-scale corpus data, 8 times the memory overhead of the text size is still something to worry about, I think. If the server doesn't have enough memory, this is likely to push a lot of data into swap, which can significantly impact performance.

Also, I'm not quite sure how tokenizer does parallelism now, e.g. how to get a mutex lock on some global data. So I'll implement the single-threaded version in my fork first.

@Yikai-Liao
Author

Or, we could add runtime dispatch for its dtype according to max_token_length, and let users decide which to use.
Oh, by the way, is max_token_length measured in bytes or in Unicode characters?

@Yikai-Liao
Author

@ArthurZucker Progress is much faster than I thought it would be. I've now passed the 3 built-in test cases. After tidying up the code and adding comments, I'll upload it to my fork first. Further parallel optimizations will follow, as well as adding dropout support.

But I think we need to add some special case tests, like how to handle strings like "00000000000".

@Yikai-Liao
Author

While modifying the BPE training code in the tokenizer, I think I've found the main reason the original implementation is slow.

Word's merge method is inefficiently implemented. Frequent remove and insert operations on a large number of vectors are very costly.

Let the number of symbols in a word be $N$, and the number of matched token pairs be $M$. The complexity of calling the merge function is then $O(M \cdot N)$, and in the worst case (a sentence that keeps repeating a single character) it is $O(N^2)$.

Luckily, tokenizers has a pre-tokenization mechanism that keeps N small, so the overall complexity stays at a manageable level.

But after a detailed comparison today, I think this merge method is still the main difference between the original implementation and mine. In my implementation, the modification at each position of the token pair takes only $O(1)$, so the work corresponding to merging within one "Word" is reduced to $O(M)$.

Moreover, changing the Vec to a linked list still wouldn't reduce the complexity of the original implementation to $O(M)$, because each call to merge traverses the entire vector of symbols, which is itself an $O(N)$ operation.

    pub(super) fn merge(
        &mut self,
        c1: u32,
        c2: u32,
        replacement: u32,
        max_length: usize,
    ) -> Vec<(Pair, i32)> {
        let mut changes: Vec<(Pair, i32)> = vec![];
        let mut i = 0;
        loop {
            if i >= self.symbols.len() {
                break;
            }

            // Found a pair
            if self.symbols[i].c == c1 && i + 1 < self.symbols.len() && self.symbols[i + 1].c == c2
            {
                let first = self.symbols[i];
                let second = self.symbols[i + 1];

                // Remove in place
                let new_s = Symbol {
                    c: replacement,
                    prev: first.prev,
                    next: second.next,
                    len: first.len + second.len,
                };

                // If there are other characters before the pair
                if i > 0 {
                    changes.push(((self.symbols[i - 1].c, first.c), -1));
                    if self.symbols[i - 1].len + new_s.len < max_length {
                        changes.push(((self.symbols[i - 1].c, replacement), 1));
                    }
                }

                self.symbols.insert(i, new_s); // Insert replacement before first char of pair
                self.symbols.remove(i + 1); // Remove first char of pair
                self.symbols.remove(i + 1); // And then the second

                // If there are other characters after the pair
                if i < self.symbols.len() - 1 {
                    changes.push(((second.c, self.symbols[i + 1].c), -1));
                    if self.symbols[i + 1].len + new_s.len < max_length {
                        changes.push(((replacement, self.symbols[i + 1].c), 1));
                    }
                }
            }

            i += 1;
        }

        changes
    }

@ArthurZucker
Collaborator

Wow, that sounds great. If we can just modify that single function, that would be pretty impressive!
I did not implement any of that so I'll have to dive in a bit once the PR is opened! 🤗 Thanks a lot already.

@AugustasMacijauskas

I found out about this article after reading in the tiktoken library's README that their code offers much faster inference speeds, and I started investigating the issues in Hugging Face's library. I know that this issue concerns training a tokenizer, but it still looks really exciting. I'm still trying to understand the details fully, but in the meantime I'm curious what the status of this is: has it been implemented and merged? Is there anything I could try to help with?

@Yikai-Liao
Author

I found out about this article after reading in the tiktoken library's README that their code offers much faster inference speeds, and I started investigating the issues in Hugging Face's library. I know that this issue concerns training a tokenizer, but it still looks really exciting. I'm still trying to understand the details fully, but in the meantime I'm curious what the status of this is: has it been implemented and merged? Is there anything I could try to help with?

@AugustasMacijauskas Thank you for your attention. I described the improvement in detail in my earlier replies. However, I found it still a bit difficult to fit it into the existing interface of tokenizers. There are two main problems:

The algorithm I proposed before limits the maximum length of a single token to 256 characters, which does not match the interface's semantics (and lifting this limit requires some changes to the algorithm's design).

The tokenizers BPE trainer also supports a continuing-subword prefix and an end-of-word suffix, which I hadn't considered before and haven't figured out how to handle.

Since I have limited energy at the moment, I haven't made much progress yet. Of course, you can take a look at my original repository (albeit with very few comments) to help you understand my implementation. Feel free to raise an issue under my repository if you have any questions!

@Yikai-Liao
Author

Yikai-Liao commented May 15, 2024

Oh, by the way, tiktoken is also written in Rust (with just about 600 lines). So it might be more feasible to introduce the Rust part of tiktoken directly into tokenizers?

@AugustasMacijauskas

When you say "introduce the rust part of tiktoken directly into tokenizers", isn't that irrelevant since tiktoken only has code for inference, while what you propose regards tokenizer training?

@Yikai-Liao
Author

When you say "introduce the rust part of tiktoken directly into tokenizers", isn't that irrelevant since tiktoken only has code for inference, while what you propose regards tokenizer training?

Sorry, I didn't look closely at tiktoken, I thought it contained both training and inference code.

By the way, it is worth noting that BPE's training and inference processes are similar, and in my own attempts this improved method can also be used for inference, although I didn't notice much of a performance gain, probably because of the small test data.

@AugustasMacijauskas

Yeah, tiktoken only contains inference code. Either way, thank you for your answers, I'll take some more time to process the code you proposed and I might come back if I have some more questions.

@ArthurZucker
Collaborator

I'll check if it's possible to include the Rust part that makes tiktoken faster here. I think they have a super efficient regex thing. Will check.

@AugustasMacijauskas

I'll check if it's possible to include the Rust part that makes tiktoken faster here. I think they have a super efficient regex thing. Will check.

That's a good idea. This actually made me realize that it'd be great to profile each part of the tokenization process separately for both tiktoken and huggingface to see what improvements can be made. Essentially, the running times of regex splitting and then of computing the tokens from the vocab should be profiled (though maybe more fine-grained profiling could be useful too). I could try looking into this, or is it well known that the regex splitting is the bottleneck?

Also, not sure how much of a difference this makes, but tiktoken operates on the byte level instead of the string level. Any possibility that this leads to performance improvements?

@AugustasMacijauskas

Oh, and they use simple Python multiprocessing to introduce parallelism instead of rayon. This is surprising since it is much simpler, yet seems to perform faster. Are there any resources I could read on why rayon is used in tokenizers?

@ArthurZucker
Collaborator

Well, we want to make sure our rust users also benefit from parallelism!

@AugustasMacijauskas

Oh, right, I somehow overlooked the fact that it's used as a standalone library as well, not just Python bindings 😅

@ArthurZucker
Collaborator

No worries! But if it can be improved I am all for it!

@marta1994

I have also recently come up with a similar algorithm as a fun exercise. I've described how it works step by step in my GitHub repo: Efficient BPE Tokenization from Scratch. The full implementation is also there. The idea is similar, but it does not need valid-position tracking or merge-status tracking. It uses only 2 data structures: a modified version of a priority queue and a modified version of a linked list. The detailed explanation is in the README of the repo.

@Yikai-Liao
Author

I have also recently come up with a similar algorithm as a fun exercise. I've described how it works step by step in my GitHub repo: Efficient BPE Tokenization from Scratch. The full implementation is also there. The idea is similar, but it does not need valid-position tracking or merge-status tracking. It uses only 2 data structures: a modified version of a priority queue and a modified version of a linked list. The detailed explanation is in the README of the repo.

@marta1994 Great work and great animation!!! 🤗 Your article is much friendlier. I notice that you only have a performance comparison with a "naive" version; could you also add the performance of the Hugging Face tokenizer? Given the inefficiencies discussed above, even a pure Python BPE trainer can achieve comparable performance.

On implementation details: your max priority map still has a suboptimal time complexity. Each time you merge a specific pair, popping from and pushing to the max priority map costs O(log M), not O(1).

    def _update_left_token(self, input_index, token_index, merge_stat, new_token):
        positions = self._positions[input_index]
        left_token_index = positions.get_previous_index(token_index)
        if left_token_index == None:
            return
        pair = (positions.get_by_index(left_token_index), merge_stat.pair[0])
        self._remove_position_from_pair(merge_stat, pair, input_index, left_token_index) # O (log M)
        new_pair = (pair[0], new_token)
        self._add_position_to_pair(new_pair, input_index, left_token_index) # O (log M)

In my implementation, by contrast, the updates to the actual position list and to the priority queue are lazy. This strategy is more efficient.

I also think the doubly linked list is not a perfect solution. If implemented in C++, it would cause severe memory fragmentation and be unfriendly to the CPU cache. Even if the memory fragmentation issue is resolved, for example with a memory pool, the memory usage of this approach would still be significantly higher than using a merge-status array (3 x 8 bytes per node versus 1 byte of status per character of the corpus).

@marta1994

@Yikai-Liao
You are right about the doubly linked list; it can be implemented without pointers. It really only serves to point to the left and right array indexes, so it can be implemented that way: the linked-list cell doesn't need to contain a pointer, it can simply contain an object with the index of the left token start, the index of the right token start, and the token value. But in that case we would still use almost the same amount of memory as with the linked list, except for the additional pointer. So there would still be logically unused memory, just not in the form of pointers. I don't think this is a huge problem though, because the list is initialized at the beginning with all the memory it will ever need; it does not grow over time. So I would probably leave the linked list as is.

I wouldn't call O(log(M)) a bad complexity. It is true that O(1) is better, but the difference is small. It is also an interesting idea to accumulate the pair edits for the pair of tokens being merged and apply them only once all occurrences of a specific pair have been replaced. This way we can merge updates for the same pairs (e.g. instead of removing positions for a pair "ab" 3 times and updating the heap 3 times, we can do it once). The complexity would not change, because you would still heapify for every distinct pair, but it could improve performance in practice, because I imagine that in the real world you can have a lot of repeating pairs in this case; e.g. if you are merging "ha" and "pp", you would have a lot of " " characters to the left and "y" characters to the right. I am curious to test that and compare. I think it is worth optimizing.

It is also important how easily understandable an algorithm is. It makes it easier to test and to spot bugs. So in some cases trade-offs between the most optimal complexity or memory usage and ease of abstraction are justifiable.

One example where you could simplify your implementation is to also remove the unused pairs in the update method (https://github.com/Yikai-Liao/efficient_bpe/blob/main/ebpe.py#L362), not only add the new pairs (as I explained in the second paragraph). It will ensure the hash set is valid when you need to retrieve the value. It also most likely won't worsen the time complexity, because currently you still potentially do multiple heappush/heappop operations to retrieve the max value, which is certainly not O(1). This would let you make the code clearer by getting rid of the self.word_pair_len variable and simplifying most_frequent_combination.

Most important, I am glad that I found your implementation to have this discussion with you!

@Yikai-Liao
Author

Yikai-Liao commented Sep 7, 2024

@marta1994 Thanks for your response. I've now thought of ways to optimize my previous implementation, although I haven't turned them into code yet. My goal now is to prototype this algorithm as soon as possible and implement its Rust version to merge into tokenizers.

The optimized version I'm envisioning will address the issues that previously prevented me from implementing the code required by the existing interface of the Hugging Face BPE trainer. And by my estimation, that version should be superior to the previous one in the constant factors of both its space and time complexity. The issues are:

  • Multi-thread support without locks
  • Max token length greater than 256
  • Sub-word prefix and end-of-word suffix support
  • Word frequency support (for corpora pre-tokenized by things like spaces; essential for languages like English)

Tomorrow I'll try to implement a prototype in python, and then start working on implementing the rust version.

As for pair_pos and word_pair_len, I did consider using a set in Python directly.

But my early tests in Python showed that using a set directly is slower than using an array, even though they have the same time complexity, i.e. the latter has a smaller constant factor. In particular, to add multithreading support, I need an ordered sequence of positions. Although I could use something like a B-tree set in Rust, I still think that not removing is definitely faster than removing.

In my new implementation, I will no longer need seg_status to mark the split status. Ultimately, only the corpus and pair_pos will be kept at $O(N)$ space, which means the new implementation will use less space.

Since this new implementation basically solves all the problems I had before, I now can't wait to implement it and share it with the community.

@Yikai-Liao
Author

I have finished the single-thread prototype implementation in Python, with the other features I mentioned above: https://github.com/Yikai-Liao/efficient_bpe/blob/main/ebpe_v2.py. It should be easy to extend it to a multi-thread version by just separating pos_list into multiple slices with guaranteed intervals between them.

I'm sorry to say that instead of organizing my code in an object-oriented way, I wrote several functions with a large number of parameters in order to implement my idea as quickly as possible.

But this time I've added a lot of English comments while writing the code, which should facilitate your understanding. I'll add the algorithm to the repository as soon as I can in the form of a README.

Feel free to share any suggestions for improvement!

Looking forward to the day when the tokenizers BPE trainer reaches SOTA in all kinds of situations!

@Yikai-Liao
Author

Yikai-Liao commented Sep 11, 2024

Good news: I have finished the single-thread version in Rust, and it passes all the original tests!

https://github.com/Yikai-Liao/tokenizers/blob/main/tokenizers/src/models/bpe/trainer.rs

There are a few things that need to be accomplished next:

  • [ ] Add more tests (e.g. with suffix and prefix, training on a large dataset) to make sure it works as expected.
  • [ ] Add multi-thread support (I think I need some help; I don't know how to write multi-threaded code in Rust).
  • [ ] Do detailed profiling to further improve the performance (using tools like VTune).
  • [ ] Compare the performance improvement with the old version in different kinds of situations.

@ArthurZucker should we open another issue or reactivate this one? I'm not very familiar with how to contribute code to large open source projects like tokenizers.

@Yikai-Liao
Author

Yikai-Liao commented Sep 11, 2024

I have done some benchmarks in different situations. The speed of the original BPE trainer shows the trend I assumed: as the average word length grows, it becomes slower.

Each of the following figures shows 5 examples of the words after preprocessing. The first time in each figure is the pre-processing time and the second is the whole training time.

  • Fig 1 (Original Trainer, split punctuation & white space)

  • Fig 2 (Original Trainer, no preprocessing)

  • Fig 3 (My Trainer, split punctuation & white space)

  • Fig 4 (My Trainer, no preprocessing)

I also tried to test them on English data (one part of TinyStories, about 133 MB), but the results turned out to be strange. The preprocessing dominates the training time (nearly 40 s), yet in Python it takes only about 1 s to split all the data on white space. If we don't split the data, the training does not finish for a very long time.

  • Fig 5 (Original Trainer, split punctuation & white space)

  • Fig 6 (Python split of the English data by white space)

My implementation also shows that counting word frequencies can significantly improve the performance of the trainer, if we don't count the preprocessing time. But the preprocessing speed does not seem reasonable; I don't understand how it could cost so much time.

  • Fig 7 (My Trainer, split punctuation & white space)

  • Fig 8 (My Trainer, no preprocessing)

But anyway, my implementation is always faster than the original algorithm, in both the best and worst cases. And in my opinion, detailed profiling of the pre-processing is needed.

Update:

I simply wrote a pre-tokenizer myself (splitting on white space and some punctuation), and it works significantly faster. The performance does not drop much even if I replace the array with a HashMap.

So what's wrong with tokenizers' pre-processor?

fn build_words(files: Vec<String>) -> HashMap<String, u64> {
    let start = std::time::Instant::now();
    let mut words = HashMap::new();
    let seps = [' ', '\n', '\t', '\r', '\x0c', '\x0b', ';', ':', ',', '.', '!', '?', '(', ')', '[', ']', '{', '}', '<', '>', '\'', '"', '`', '~', '@', '#', '$', '%', '^', '&', '*', '-', '_', '+', '=', '\\', '|', '/', ' ', '\n', '\t'];
    // build an array of bool that returns true if the char is a separator
    let mut is_sep = [false; 128];
    for &c in seps.iter() {
        is_sep[c as usize] = true;
    }

    for file_path in files {
        let file = std::fs::read_to_string(file_path).unwrap();
        let mut pivot = 0;
        for (i, c) in file.chars().enumerate() {
            let c = c as usize;
            if c < 128 && is_sep[c] {
                if i > pivot + 1 {
                    let word = &file[pivot..i];
                    *words.entry(word.to_string()).or_insert(0) += 1;
                }
                pivot = i + 1;
            }
        }
    }
    let dur = start.elapsed();
    println!("build_words(hash) took {:?}", dur);
    words
}

@ArthurZucker
Collaborator

Sorry for not answering sooner! 🤗
Yeah I'll have a look, the thing with new trainers is that some people want 1-1 matching results with the original sentencepiece BPE -> and thus with our bpe trainer as well. I don't really mind that and if there is improvement to be made to have a way faster training it's welcome IMO! The hard part is that for now we only have a single kind of BPETrainer, API is not super friendly to adding a new one.

But we can try! I am adding a new BPE here: #1712 that should be more efficient as well!
