diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c09c11050041..f98b3c6e84cb 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -26,6 +26,7 @@ import shutil import tempfile import warnings +from bisect import bisect_left from contextlib import contextmanager from dataclasses import dataclass from functools import partial, wraps @@ -565,9 +566,20 @@ def set_initialized_submodules(model, state_dict_keys): Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state dict. """ + # So we can do binary search on it - this becomes important when it's big + state_dict_keys = sorted(state_dict_keys) + not_initialized_submodules = {} for module_name, module in model.named_modules(): - loaded_keys = {k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")} + # loaded_keys is the set of keys that are in state_dict_keys and start with module_name + "." + prefix = module_name + "." + loaded_keys = set() + # Use binary search to find the start of the keys that have this prefix + i = bisect_left(state_dict_keys, prefix) + while i < len(state_dict_keys) and state_dict_keys[i].startswith(prefix): + # Iterate until we reach the end of the keys with this prefix + loaded_keys.add(state_dict_keys[i].removeprefix(prefix)) + i += 1 # When checking if the root module is loaded all state_dict_keys must be used. if module_name == "": loaded_keys = set(state_dict_keys)