
Fix set_initialized_submodules bugs + improve asymptotic runtime #35698

Open
wants to merge 5 commits into
base: main
from


@Rocketknight1 (Member) commented Jan 14, 2025

Our old set_initialized_submodules function had this line:

```python
loaded_keys = {k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")}
```

This line has two problems. First, it is buggy: str.replace() removes every occurrence of the substring "{module_name}.", not only the one at the start of the key, so a key that contains that substring again further along gets mangled. In testing, this was corrupting keys in several models, particularly models that had a key on the base model object with a name like "embeddings".
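A minimal illustration of the difference, using hypothetical names (`embeddings` as the module name, `embeddings.word_embeddings.weight` as a state dict key). Note that str.removeprefix() requires Python 3.9+.

```python
module_name = "embeddings"
key = "embeddings.word_embeddings.weight"  # hypothetical state dict key
prefix = f"{module_name}."

# Old behaviour: replace() strips *every* occurrence of "embeddings.",
# including the one inside "word_embeddings.".
old = key.replace(prefix, "")

# Fixed behaviour: removeprefix() only strips a leading match.
new = key.removeprefix(prefix)

print(old)  # word_weight
print(new)  # word_embeddings.weight
```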

The second problem is performance: this line iterates over every key in the state dict. That iteration is O(N), where N is the number of weights in the state dict, and it runs once per module in the model. Since there are also O(N) modules, the overall function is O(N^2). Like most O(N^2) code, it is fine until the inputs get large enough, and then the runtime blows up, which is what recently happened with DeepSeek-v3.
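To make the quadratic growth concrete, here is a small sketch that counts how many keys the old per-module scan touches on a hypothetical model in which every module `layers.i` owns a weight and a bias (so N = 2 × number of modules). The helper names `synthetic_model` and `count_full_scan_checks` are made up for illustration.

```python
def synthetic_model(num_modules):
    """Hypothetical model: each module "layers.i" owns a weight and a bias."""
    module_names = [f"layers.{i}" for i in range(num_modules)]
    keys = [f"{m}.{p}" for m in module_names for p in ("weight", "bias")]
    return module_names, keys


def count_full_scan_checks(module_names, state_dict_keys):
    """Count how many keys the old full scan inspects across all modules."""
    checks = 0
    for module_name in module_names:
        prefix = f"{module_name}."
        # Same shape as the old comprehension: it visits every key, every time.
        _loaded = {k.removeprefix(prefix) for k in state_dict_keys if k.startswith(prefix)}
        checks += len(state_dict_keys)
    return checks


for n in (100, 200, 400):
    modules, keys = synthetic_model(n)
    print(n, len(keys), count_full_scan_checks(modules, keys))
```

Each doubling of the model size quadruples the number of key checks (20,000, then 80,000, then 320,000), which is the O(N^2) behaviour described above.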

This PR replaces the per-module full scan with a single sort of the keys, O(N log N), followed by a per-module binary search, O(log N), to find the first key with the module's prefix, and then iterates forward until the prefix no longer matches. Each module's lookup is therefore O(log N + k), where k is the average number of weights in a module, and the overall runtime is O(N log N + Nk), which is much faster than O(N^2) when k << N. It also fixes the incorrect-replacement bug by using removeprefix() instead of replace().
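A minimal sketch of the sort-plus-binary-search lookup, using Python's standard `bisect` module. This is not the PR's actual diff: the helper name `loaded_keys_for_module` and the key names are hypothetical, and `str.removeprefix()` requires Python 3.9+.

```python
from bisect import bisect_left


def loaded_keys_for_module(sorted_keys: list[str], module_name: str) -> set[str]:
    """Return the keys under `module_name`, relative to that module.

    `sorted_keys` must already be sorted lexicographically (done once, up front).
    """
    prefix = f"{module_name}."
    # All keys starting with `prefix` form one contiguous run in sorted order,
    # and that run begins at the first key >= prefix: O(log N) to find it.
    i = bisect_left(sorted_keys, prefix)
    loaded = set()
    # Walk forward only while the prefix still matches: O(k).
    while i < len(sorted_keys) and sorted_keys[i].startswith(prefix):
        loaded.add(sorted_keys[i].removeprefix(prefix))  # strips only the leading prefix
        i += 1
    return loaded


state_dict_keys = [
    "pooler.dense.bias",
    "encoder.layer.10.weight",
    "embeddings.word_embeddings.weight",
    "encoder.layer.1.weight",
    "encoder.layer.0.weight",
]
sorted_keys = sorted(state_dict_keys)  # one O(N log N) sort shared by all modules

for name in ("embeddings", "encoder.layer.1", "encoder"):
    print(name, sorted(loaded_keys_for_module(sorted_keys, name)))
```

Because the trailing "." is part of the prefix, `encoder.layer.10.weight` is not picked up when looking up `encoder.layer.1`, and `embeddings.word_embeddings.weight` keeps its inner `embeddings` when relativized.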

Fixes #35635

@Rocketknight1 changed the title Make set_initialized_submodules O(kN + log(N)) instead of O(N^2), where k << N to Make set_initialized_submodules O(kN + Nlog(N)) instead of O(N^2), where k << N Jan 14, 2025
@Rocketknight1 changed the title Make set_initialized_submodules O(kN + Nlog(N)) instead of O(N^2), where k << N to Fix set_initialized_submodules bugs + improve asymptotic runtime Jan 14, 2025

Development

Successfully merging this pull request may close these issues.

set_initialized_submodules too slow when loading big model like DeepSeekV3