From 88aac166dbf3720c3ba127d35954576e1c3b1ffe Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 14 Jan 2025 17:42:33 +0000 Subject: [PATCH 1/5] Make set_initialized_submodules O(kN + log(N)) instead of O(N^2), where k << N --- src/transformers/modeling_utils.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c09c11050041..767b625836c1 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -32,6 +32,7 @@ from threading import Thread from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union from zipfile import is_zipfile +from bisect import bisect_left, bisect_right import torch from huggingface_hub import split_torch_state_dict_into_shards @@ -565,9 +566,22 @@ def set_initialized_submodules(model, state_dict_keys): Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state dict. """ + state_dict_keys = sorted(state_dict_keys) # So we can do binary search on it - this becomes important when it's big + not_initialized_submodules = {} for module_name, module in model.named_modules(): - loaded_keys = {k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")} + # loaded_keys is the set of keys that are in state_dict_keys and start with module_name + "." + prefix = module_name + "." + loaded_keys = set() + # Use binary search to find the start of the keys that have this prefix + i = bisect_left(state_dict_keys, prefix) + while i < len(state_dict_keys) and state_dict_keys[i].startswith(prefix): + # Iterate until we reach the end of the keys with this prefix + loaded_keys.add(state_dict_keys[i].removeprefix(prefix)) + i += 1 + # This next block is just for debug testing and should be removed before merging + old_loaded_keys = {k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")} + assert old_loaded_keys == loaded_keys # When checking if the root module is loaded all state_dict_keys must be used. if module_name == "": loaded_keys = set(state_dict_keys) From 3ec087ed73bae7b83a813fef6c010e1c32c6a35a Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 14 Jan 2025 17:49:49 +0000 Subject: [PATCH 2/5] make fixup --- src/transformers/modeling_utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 767b625836c1..5ac7dd1cad6b 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -26,13 +26,13 @@ import shutil import tempfile import warnings +from bisect import bisect_left from contextlib import contextmanager from dataclasses import dataclass from functools import partial, wraps from threading import Thread from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union from zipfile import is_zipfile -from bisect import bisect_left, bisect_right import torch from huggingface_hub import split_torch_state_dict_into_shards @@ -566,7 +566,9 @@ def set_initialized_submodules(model, state_dict_keys): Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state dict. """ - state_dict_keys = sorted(state_dict_keys) # So we can do binary search on it - this becomes important when it's big + state_dict_keys = sorted( + state_dict_keys + ) # So we can do binary search on it - this becomes important when it's big not_initialized_submodules = {} for module_name, module in model.named_modules(): @@ -580,7 +582,9 @@ def set_initialized_submodules(model, state_dict_keys): loaded_keys.add(state_dict_keys[i].removeprefix(prefix)) i += 1 # This next block is just for debug testing and should be removed before merging - old_loaded_keys = {k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")} + old_loaded_keys = { + k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.") + } assert old_loaded_keys == loaded_keys # When checking if the root module is loaded all state_dict_keys must be used. if module_name == "": From dcbc8c9ccee3e37137e2d6f5e4d12ad1ca286fb4 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 14 Jan 2025 18:05:07 +0000 Subject: [PATCH 3/5] Fix the old keys comparison --- src/transformers/modeling_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5ac7dd1cad6b..fd3d1b27ebdf 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -583,9 +583,10 @@ def set_initialized_submodules(model, state_dict_keys): i += 1 # This next block is just for debug testing and should be removed before merging old_loaded_keys = { - k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.") + k.removeprefix(f"{module_name}.") for k in state_dict_keys if k.startswith(f"{module_name}.") } - assert old_loaded_keys == loaded_keys + if not old_loaded_keys == loaded_keys: + breakpoint() # When checking if the root module is loaded all state_dict_keys must be used. if module_name == "": loaded_keys = set(state_dict_keys) From edda0c13906c621d8839f1287110864940f2bba2 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 14 Jan 2025 18:12:04 +0000 Subject: [PATCH 4/5] Formatting cleanup --- src/transformers/modeling_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index fd3d1b27ebdf..c769de0fb82b 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -566,9 +566,8 @@ def set_initialized_submodules(model, state_dict_keys): Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state dict. """ - state_dict_keys = sorted( - state_dict_keys - ) # So we can do binary search on it - this becomes important when it's big + # So we can do binary search on it - this becomes important when it's big + state_dict_keys = sorted(state_dict_keys) not_initialized_submodules = {} for module_name, module in model.named_modules(): From cc6f662a54c93275499982ffa49cf475eb2328ec Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 14 Jan 2025 18:24:57 +0000 Subject: [PATCH 5/5] Testing success, remove debug block --- src/transformers/modeling_utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c769de0fb82b..f98b3c6e84cb 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -580,12 +580,6 @@ def set_initialized_submodules(model, state_dict_keys): # Iterate until we reach the end of the keys with this prefix loaded_keys.add(state_dict_keys[i].removeprefix(prefix)) i += 1 - # This next block is just for debug testing and should be removed before merging - old_loaded_keys = { - k.removeprefix(f"{module_name}.") for k in state_dict_keys if k.startswith(f"{module_name}.") - } - if not old_loaded_keys == loaded_keys: - breakpoint() # When checking if the root module is loaded all state_dict_keys must be used. if module_name == "": loaded_keys = set(state_dict_keys)