Fix set_initialized_submodules bugs + improve asymptotic runtime #35698
+13
−1
Our old `set_initialized_submodules` function had a line with two problems.

Firstly, it's buggy: it can replace the substring `"{module_name}."` even if it does not occur at the start of the string. In testing, this was mangling keys in several models, particularly models which had a key on the base model object with a name like `"embeddings"`.

The second problem is performance: this line iterates over every key in the state dict. That iteration is O(N), where N is the number of weights in the state dict, and it is executed once per `module` in the model. Since there are O(N) modules as well, the overall function is O(N^2). As is characteristic of O(N^2) functions, it's fine until inputs get sufficiently large, and then its runtime totally blows up, which happened recently with DeepSeek-v3.

This PR replaces the full list scan with a sort, which is O(N log N), followed by a binary search, which is O(log N), and then iterates forward until the prefix no longer matches. Therefore, each scan is now O(log N + k), where k is the average number of weights in a module, and the overall runtime is O(N log N + Nk), which is much faster than O(N^2). It also fixes the incorrect replacement bug by using `removeprefix()` instead of `replace()`.

Fixes #35635
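For reviewers, here is a minimal standalone sketch of the approach described above. It is not the code in this PR's diff: the helper name `keys_under_prefix` and the sample keys are hypothetical, chosen only to illustrate the sort, binary search, and forward scan, and the difference between `replace()` and `removeprefix()`. It assumes Python 3.9+ for `str.removeprefix`.

```python
from bisect import bisect_left


def keys_under_prefix(sorted_keys, module_name):
    """Return the keys that live under `module_name`, with the prefix stripped.

    `sorted_keys` must already be sorted. Hypothetical helper for illustration.
    """
    prefix = f"{module_name}."
    # Binary search for the first key >= prefix: O(log N).
    i = bisect_left(sorted_keys, prefix)
    found = []
    # Keys sharing a prefix are contiguous in sorted order, so walk forward
    # until the prefix stops matching: O(k) for k matching keys.
    while i < len(sorted_keys) and sorted_keys[i].startswith(prefix):
        # removeprefix() only strips the leading occurrence.
        found.append(sorted_keys[i].removeprefix(prefix))
        i += 1
    return found


# Sample state dict keys (made up for this sketch). Sort once: O(N log N).
state_dict_keys = sorted(
    [
        "lm_head.weight",
        "embeddings.word_embeddings.weight",
        "embeddings.LayerNorm.weight",
        "encoder.layer.0.weight",
    ]
)

# replace() strips every occurrence of the substring, which mangles the key.
mangled = "embeddings.word_embeddings.weight".replace("embeddings.", "")

# The sorted + bisect lookup with removeprefix() keeps the rest of the key intact.
embedding_keys = keys_under_prefix(state_dict_keys, "embeddings")

print(mangled)
print(embedding_keys)
```

Here `replace()` turns the key into `word_weight`, because `"embeddings."` also appears inside `word_embeddings.`, whereas the prefix-based lookup returns `word_embeddings.weight` unchanged after the leading prefix.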